Add Metal (MPS) acceleration for Apple Silicon - 18x speedup #78
Open
imperatormk wants to merge 3 commits into alxndrTL:main from
Conversation
~18x speedup on M1/M2/M3/M4 Macs, making Mamba practical for training on Apple Silicon.

Performance (M3 Max, batch=1024, seq_len=2, d_model=384):
- PyTorch MPS baseline: ~120ms (57x slower than attention)
- Metal fused kernels: ~7ms (3.2x slower than attention)

Changes:
- metal_pscan/: Native Metal compute shaders with PyTorch autograd integration
- setup_metal.py: Build script for the Metal extension
- mambapy/mamba.py: Auto-detection and use of Metal kernels on MPS device

Fused kernels:
1. conv1d_silu_fused: Depthwise conv1d + SiLU in one pass (12x faster)
2. ssm_fused: SSM state prep + parallel scan
3. ssm_output_fused: Full SSM + output matmul (8x faster, avoids large intermediate)

Installation:

    pip install -e .
    python setup_metal.py build_ext --inplace

The kernels auto-activate when running on the MPS device and fall back to PyTorch on other devices.
Use stream->commandEncoder() instead of COMMIT_AND_WAIT for zero-overhead integration with PyTorch's MPS backend. Bumped to v1.2.1.
- Register metal_pscan ops with torch.library for torch.compile compatibility
- Add torch_ops.py with MPS and Meta implementations for all ops
- Update mamba.py to use torch.ops.metal_pscan.* when available
- Backend selection happens at import time (zero overhead in hot path)
- Falls back to direct _C.* calls if torch_ops fails to load
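The import-time backend selection described in this commit can be sketched roughly as follows. This is a hedged illustration of the try/fall-back pattern, not the PR's actual code; the module paths `metal_pscan.torch_ops` and `metal_pscan._C` are assumed names based on the commit message.

```python
def select_backend():
    """Pick the op provider once, at import time, so the hot path
    never has to branch on backend availability."""
    try:
        # Preferred path: ops registered via torch.library,
        # which keeps them torch.compile-compatible.
        from metal_pscan import torch_ops  # hypothetical module path
        return torch_ops
    except ImportError:
        try:
            # Fallback: call the C extension bindings directly.
            from metal_pscan import _C  # hypothetical module path
            return _C
        except ImportError:
            # Neither available: caller uses the pure-PyTorch path.
            return None

# Resolved exactly once when the module is imported.
_BACKEND = select_backend()
```

Resolving the backend at import time is what gives the "zero overhead in hot path" property: the forward pass just calls through `_BACKEND` without re-checking availability.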
Summary
Native Metal compute shaders for ~18x speedup on Apple Silicon (M1/M2/M3/M4), making Mamba practical for training on Mac.
Performance
On M3 Max with batch=1024, seq_len=2, d_model=384:

- PyTorch MPS baseline: ~120ms (57x slower than attention)
- Metal fused kernels: ~7ms (3.2x slower than attention)
Changes
- metal_pscan/: Native Metal compute shaders with PyTorch autograd integration
- setup_metal.py: Build script for the Metal extension
- mambapy/mamba.py: Auto-detection and use of Metal kernels on MPS device

Fused Kernels

1. conv1d_silu_fused: Depthwise conv1d + SiLU in one pass (12x faster)
2. ssm_fused: SSM state prep + parallel scan
3. ssm_output_fused: Full SSM + output matmul (8x faster, avoids large intermediate)
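For reference, the operation that `conv1d_silu_fused` fuses into one pass can be written out in plain Python. This is a sketch of the semantics only, not the Metal kernel; the causal left-padding and per-channel (depthwise) filters are assumptions based on Mamba's standard conv layer.

```python
import math

def silu(x):
    # SiLU / swish activation: x * sigmoid(x)
    return x * (1.0 / (1.0 + math.exp(-x)))

def depthwise_conv1d_silu(x, weight):
    """Causal depthwise conv1d followed by SiLU.

    x:      list of channels, each a list of T input values
    weight: list of channels, each a list of K filter taps
    Left-pads each channel with K-1 zeros so the output length
    equals the input length (causal padding).
    """
    out = []
    for ch, w in zip(x, weight):
        K = len(w)
        padded = [0.0] * (K - 1) + list(ch)
        row = []
        for t in range(len(ch)):
            acc = sum(padded[t + k] * w[k] for k in range(K))
            row.append(silu(acc))  # fused: activation applied in the same pass
        out.append(row)
    return out
```

The fusion wins because the unfused version materializes the full conv output in memory before the activation pass; doing both per-element in one kernel avoids that round trip.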
Installation
    pip install -e .
    python setup_metal.py build_ext --inplace

Usage
Kernels auto-activate when running on the MPS device; no code changes are needed.
Falls back gracefully to PyTorch on CUDA/CPU.
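The dispatch rule behind that fallback can be sketched as a small predicate. This is an illustration of the selection logic described above, with a hypothetical `metal_backend` handle standing in for the loaded extension:

```python
def pick_kernel_path(device_type, metal_backend):
    """Use the Metal kernels only when the tensors live on MPS
    *and* the extension actually loaded; otherwise take the
    stock PyTorch path (CUDA/CPU, or MPS without the extension)."""
    if device_type == "mps" and metal_backend is not None:
        return "metal"
    return "pytorch"
```

Because the check keys only on device type and extension availability, the same model code runs unmodified everywhere; only the kernel path changes.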
Tested On