Summary
On a Mac with an M5, mx.addmm and A @ B are consistently ~1.1–1.2× slower than PyTorch's addmm / matmul on MPS for 1280×1280 BF16 matrices. This GEMM shape is representative of Nano Chat–style transformer training, so the gap directly reduces end-to-end training throughput vs PyTorch.
MLX sped up after Neural Accelerator (NA) support landed, but it still trails PyTorch even with official Neural Accelerator support.
Repro
Gist with script and logs: https://gist.github.com/Anemll/5420800c3d29c7fae18a2b9b10907b14
Key details:
Shape: (1280, 1280)
Dtype: bfloat16
Device: PyTorch mps, MLX Metal
Warmup: 30, Iterations: 1000
Ops tested:
addmm: beta * C + alpha * (A @ B)
matmul: A @ B
Sync:
PyTorch: torch.mps.synchronize()
MLX: mx.eval(result)
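For reference, the timing methodology above can be sketched as follows. This is a minimal harness, not the exact script from the gist; `bench_ms` is a name I made up, and the commented-out PyTorch/MLX wiring assumes an Apple-silicon Mac with both frameworks installed.

```python
import time

def bench_ms(fn, sync, warmup=30, iters=1000):
    """Mean per-call latency in ms: run fn `warmup` times, then `iters`
    timed times, forcing device completion via `sync` after each call."""
    for _ in range(warmup):
        sync(fn())
    start = time.perf_counter()
    for _ in range(iters):
        sync(fn())
    return (time.perf_counter() - start) / iters * 1e3

# Hypothetical wiring for the two backends (requires Apple-silicon hardware):
#
# PyTorch MPS:
#   A = torch.randn(1280, 1280, device="mps", dtype=torch.bfloat16)
#   ms = bench_ms(lambda: A @ A, lambda _: torch.mps.synchronize())
#
# MLX (lazy, so mx.eval on the result doubles as the sync):
#   A = mx.random.normal((1280, 1280)).astype(mx.bfloat16)
#   ms = bench_ms(lambda: A @ A, mx.eval)
```

Syncing after every call (rather than once at the end) keeps the two backends comparable despite MLX's lazy evaluation.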
Results (MLX 0.31.0)
PyTorch MPS:
addmm: ~0.596 ms
matmul: ~0.549 ms
MLX:
addmm: ~0.651 ms
matmul: ~0.646 ms
Ratios:
addmm: 1.09× slower (MLX / PyTorch)
matmul: 1.18× slower
Motivation (Nano Chat)
Small LLM (“Nano Chat”) training is GEMM‑bound at these sizes; this 10–20% gap in BF16 GEMM on M‑series Macs translates almost directly into slower training steps vs PyTorch MPS. Closing this gap would make MLX more competitive as the default backend for local Nano Chat training.