Conversation

@dzoba dzoba commented Nov 23, 2025

Proposed changes

Improve batched GEMM performance by enabling swizzle for operations with large tile grids.

Modified the swizzle heuristic in `steel_matmul_regular_axpby` to use `swizzle_log = 1` when:

  • `batch_size_out > 1` (batched operations)
  • `tm >= 8 && tn >= 8` (a large tile grid, i.e., M >= 512 and N >= 512 with 64x64 blocks)

This improves cache efficiency for batched matrix multiplications commonly used in LLM training (attention projections, FFN layers).
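For clarity, the condition can be expressed as a small standalone helper. This is a minimal sketch, not the actual diff: `choose_swizzle_log` is a hypothetical name (in the PR the check is inlined in `steel_matmul_regular_axpby`), and `batch_size_out`, `tm`, and `tn` are assumed to mean the output batch size and the tile-grid dimensions, i.e., ceil(M / bm) and ceil(N / bn).

```cpp
// Hypothetical helper expressing the new heuristic; in the PR the condition
// is inlined in steel_matmul_regular_axpby rather than factored out.
inline int choose_swizzle_log(int batch_size_out, int tm, int tn) {
  // Enable one level of tile swizzling only for batched operations with a
  // large tile grid. With 64x64 blocks, tm >= 8 && tn >= 8 corresponds to
  // M >= 512 and N >= 512.
  if (batch_size_out > 1 && tm >= 8 && tn >= 8) {
    return 1;
  }
  // Non-batched operations (and small tile grids) keep the previous
  // behavior of swizzle_log = 0.
  return 0;
}
```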

Benchmark Results

Tested on Apple Silicon M4 Max:

| Shape (B, M, N, K), float32 | Before (vs PyTorch) | After (vs PyTorch) | Improvement |
| --- | --- | --- | --- |
| (16, 1024, 1024, 1024) | -7.6% | -0.3% | +7.3 pp |
| (4, 1024, 1024, 4096) | -17.6% | +8.1% | +25.7 pp |
| (4, 1024, 4096, 1024) | -9.9% | +20.3% | +30.2 pp |

Non-batched operations (batch=1) are unchanged as they don't benefit from swizzle.
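For reference, one shape from the table can be timed from the C++ API along these lines. This is a rough sketch under assumptions, not the harness used for the numbers above: it times only the MLX side (the PyTorch comparison is separate), the iteration count is arbitrary, and it assumes the `mlx::core` API (`random::normal`, `matmul`, `eval`) available when building against the MLX headers.

```cpp
// Rough timing sketch for one batched shape from the table, (4, 1024, 1024, 4096).
// Illustration only; not the benchmark used for the reported numbers.
#include <chrono>
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto a = mx::random::normal({4, 1024, 1024}); // (B, M, K)
  auto b = mx::random::normal({4, 1024, 4096}); // (B, K, N)
  mx::eval({a, b});

  // Warm-up so kernel selection/compilation is not included in the timing.
  mx::eval({mx::matmul(a, b)});

  constexpr int iters = 50;
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) {
    mx::eval({mx::matmul(a, b)});
  }
  auto end = std::chrono::steady_clock::now();

  std::cout << "batched GEMM: "
            << std::chrono::duration<double, std::milli>(end - start).count() / iters
            << " ms/iter\n";
  return 0;
}
```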

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Enable swizzle_log=1 for batched operations with large tile grids
(batch > 1, tm >= 8, tn >= 8) to improve cache efficiency.

Benchmarks show a 7-30 percentage-point improvement relative to PyTorch on common LLM training shapes:
- (16, 1024, 1024, 1024): -7.6% -> -0.3% vs PyTorch
- (4, 1024, 1024, 4096): -17.6% -> +8.1%
- (4, 1024, 4096, 1024): -9.9% -> +20.3%
@awni awni (Member) commented Dec 1, 2025

@jagrit06 what do you think about this? Results look quite nice to me.
