From 909bf3c628ebdc1874fead31a6a6e0c577f1d154 Mon Sep 17 00:00:00 2001
From: Chris Dzoba
Date: Sun, 23 Nov 2025 15:05:54 -0500
Subject: [PATCH] Improve batched GEMM performance with swizzle tuning

Enable swizzle_log=1 for batched operations with large tile grids
(batch > 1, tm >= 8, tn >= 8) to improve cache efficiency.

Benchmarks show 7-30% improvement on common LLM training shapes:
- (16, 1024, 1024, 1024): -7.6% -> -0.3% vs PyTorch
- (4, 1024, 1024, 4096): -17.6% -> +8.1%
- (4, 1024, 4096, 1024): -9.9% -> +20.3%
---
 mlx/backend/metal/matmul.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index add11c1466..c3c7b8411c 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -454,7 +454,8 @@ void steel_matmul_regular_axpby(
   int tm = (M + bm - 1) / bm;
 
   // TODO: Explore device-based tuning for swizzle
-  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
+  // Use swizzle for batched operations with larger tile grids
+  int swizzle_log = (batch_size_out > 1 && tm >= 8 && tn >= 8) ? 1 : 0;
 
   // Prepare steel matmul params
   GEMMParams params{
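
Note for reviewers: the sketch below is a minimal host-side C++ illustration of
what a swizzle does to the tile schedule, not the MLX kernel source. The exact
remapping formula and the tm/tn values are assumptions chosen to mirror the
common steel-GEMM swizzle pattern; the real mapping lives in the Metal kernels
and GEMMParams. The idea is that with swizzle_log=1 consecutively dispatched
threadgroups land on neighbouring tiles, so A/B tile loads are more likely to
hit in cache.

```cpp
// Illustrative sketch (assumed mapping, not MLX kernel code): how swizzle_log=1
// regroups a 2-D tile grid so consecutively scheduled threadgroups touch
// neighbouring tiles.
#include <cstdio>

int main() {
  const int tm = 8, tn = 8;     // tile grid size, matching the patch threshold
  const int swizzle_log = 1;    // value enabled by this patch for batched GEMMs
  const int group = 1 << swizzle_log;

  // The launch grid is reshaped: rows folded by `group`, columns expanded.
  const int grid_y = (tm + group - 1) / group;
  const int grid_x = tn * group;

  // Each dispatched (x, y) maps back to a tile (tile_x, tile_y).
  for (int y = 0; y < grid_y; ++y) {
    for (int x = 0; x < grid_x; ++x) {
      const int tile_y = (y << swizzle_log) + (x & (group - 1));
      const int tile_x = x >> swizzle_log;
      if (tile_y < tm && tile_x < tn) {
        std::printf("dispatch (%d,%d) -> tile (%d,%d)\n", x, y, tile_x, tile_y);
      }
    }
  }
  return 0;
}
```

Printing the mapping shows pairs of adjacent tile rows being interleaved along
x, which is why the benefit only shows up once the grid is large enough
(hence the tm >= 8 && tn >= 8 gate in the patch).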