Improve batched gemm swizzle #2824
Open
Proposed changes
Improve batched GEMM performance by enabling swizzle for operations with large tile grids.
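To illustrate what tile swizzling does here, below is a minimal sketch of a common threadgroup-swizzle remapping used in GEMM kernels. The helper name and exact bit layout are hypothetical; the actual indexing in the MLX steel kernels may differ, but the idea is the same: fold groups of `2**swizzle_log` consecutive x-tiles into the y direction so that threadgroups launched back-to-back cover a compact patch of the tile grid and reuse the same cached operand tiles.

```python
def swizzle_tile(tid_x: int, tid_y: int, swizzle_log: int) -> tuple[int, int]:
    """Hypothetical sketch: remap a 2D tile index so that 2**swizzle_log
    consecutive tiles along x are folded into y, improving L2 reuse."""
    mask = (1 << swizzle_log) - 1
    new_y = (tid_y << swizzle_log) + (tid_x & mask)
    new_x = tid_x >> swizzle_log
    return new_x, new_y

# With swizzle_log = 1, the first four tiles in launch order,
# (0,0), (1,0), (2,0), (3,0), remap to (0,0), (0,1), (1,0), (1,1):
# a compact 2x2 patch of the tile grid instead of a 4x1 strip.
print([swizzle_tile(x, 0, 1) for x in range(4)])
```

With `swizzle_log = 0` the mapping is the identity, which is why the non-batched path is unaffected when the heuristic leaves swizzle off.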
Modified the swizzle heuristic in `steel_matmul_regular_axpby` to use `swizzle_log = 1` when:

- `batch_size_out > 1` (batched operations)
- `tm >= 8 && tn >= 8` (large tile grids, i.e., M >= 512 and N >= 512 with 64x64 blocks)

This improves cache efficiency for batched matrix multiplications commonly used in LLM training (attention projections, FFN layers).
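The conditions above can be sketched as a small predicate. This is a hypothetical reconstruction of the heuristic described in this PR, not the actual MLX source; `use_swizzle` and its parameters are illustrative names, with `bm`/`bn` standing in for the 64x64 block size.

```python
def use_swizzle(batch_size_out: int, M: int, N: int, bm: int = 64, bn: int = 64) -> int:
    """Hypothetical sketch of the heuristic: return the swizzle_log value.

    tm and tn are the number of 64x64 tiles along M and N; the heuristic
    enables swizzle (swizzle_log = 1) only for batched ops with a tile
    grid of at least 8x8, i.e. M >= 512 and N >= 512.
    """
    tm = (M + bm - 1) // bm
    tn = (N + bn - 1) // bn
    if batch_size_out > 1 and tm >= 8 and tn >= 8:
        return 1
    return 0

# Batched 512x512 GEMM gets swizzle; the same shape unbatched does not.
print(use_swizzle(4, 512, 512), use_swizzle(1, 512, 512))
```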
Benchmark Results
Tested on an Apple Silicon M4 Max.

Non-batched operations (batch = 1) are unchanged, as they don't benefit from swizzle.
Checklist
Put an `x` in the boxes that apply.

- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes