I am stress-testing the performance of matmul-like kernels on 1x TPU v4, but I can't push the performance above 50% of the peak FLOPS (275 TFLOPS) specified in the paper. The following code shows the measurement:

```python
import jax
import jax.numpy as jnp

device = jax.default_device()
print(f"Jax default backend: {jax.default_backend()}")

key = jax.random.key(0)
x_bf16 = jax.random.uniform(key, (2**16, 4096), dtype=jnp.bfloat16)
y_bf16 = jax.random.uniform(key, (4096, 4096), dtype=jnp.bfloat16)

dot = jax.jit(jax.lax.dot)  # .lower(x_bf16, y_bf16).compile()
dot(x_bf16, y_bf16).block_until_ready()
%timeit dot(x_bf16, y_bf16).block_until_ready()
```

which gives a throughput of only about half the peak. Is it because the code is not using the other processor on the TPU v4? If so, how do I make it use both?
Are you missing a factor of 2 in your flop count? I.e., for an `[M, N]` by `[N, K]` matmul you need `M*N*K` fused multiply-adds, which is `2*M*N*K` flops, since each FMA counts as two ops (one multiply, one add).
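
A minimal sketch of the corrected throughput calculation, using the shapes from the question. The function name `achieved_tflops` and the 16 ms example time are hypothetical, just to illustrate the arithmetic:

```python
M, K = 2**16, 4096   # shape of x_bf16
N = 4096             # second dim of y_bf16

# A [M, K] by [K, N] matmul performs M*K*N fused multiply-adds.
# Each FMA counts as 2 floating-point ops, so the total is 2*M*K*N,
# not M*K*N -- forgetting the factor of 2 halves the reported throughput.
flops = 2 * M * K * N

def achieved_tflops(t_seconds):
    """Throughput in TFLOP/s for a measured per-call time."""
    return flops / t_seconds / 1e12

# Example: a hypothetical 16 ms per call would give
# achieved_tflops(16e-3) ~= 137 TFLOP/s, i.e. about 50% of the
# 275 TFLOP/s peak; counted as only M*K*N flops it would look like ~25%.
```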