
Add Metal (MPS) acceleration for Apple Silicon - 18x speedup#78

Open
imperatormk wants to merge 3 commits into alxndrTL:main from imperatormk:feature/metal-mps-acceleration

Conversation

@imperatormk

Summary

Native Metal compute shaders for ~18x speedup on Apple Silicon (M1/M2/M3/M4), making Mamba practical for training on Mac.

Performance

On M3 Max with batch=1024, seq_len=2, d_model=384:

Implementation            Time     vs Attention
PyTorch MPS (baseline)    ~120ms   57x slower
Metal fused kernels       ~7ms     3.2x slower

Changes

  • metal_pscan/: Native Metal compute shaders with PyTorch autograd integration
  • setup_metal.py: Build script for the Metal extension
  • mambapy/mamba.py: Auto-detection and use of Metal kernels on MPS device

Fused Kernels

  1. conv1d_silu_fused: Depthwise conv1d + SiLU in one pass (12x faster)
  2. ssm_fused: SSM state prep + parallel scan
  3. ssm_output_fused: Full SSM + output matmul (8x faster, avoids large intermediate tensor)
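
For reference, the unfused computation that conv1d_silu_fused collapses into one pass is a depthwise (per-channel) conv1d followed by SiLU. A minimal NumPy sketch of that reference semantics follows; the shapes and the causal left-padding are assumptions for illustration, not taken from the kernel source:

```python
import numpy as np

def silu(x):
    # SiLU / swish: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def depthwise_conv1d_silu(x, w):
    """Unfused reference: depthwise causal conv1d + SiLU.

    x: (L, D) sequence, w: (K, D) one length-K filter per channel.
    Causal left-padding keeps the output length equal to the input length.
    """
    L, D = x.shape
    K = w.shape[0]
    xp = np.concatenate([np.zeros((K - 1, D)), x], axis=0)  # pad on the left
    out = np.empty_like(x, dtype=float)
    for t in range(L):
        # depthwise: each channel convolves only with its own filter
        out[t] = np.sum(xp[t:t + K] * w, axis=0)
    return silu(out)
```

A fused kernel produces the same result in a single pass over memory, avoiding the materialization of the pre-activation tensor between the convolution and the nonlinearity.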

Installation

pip install -e .
python setup_metal.py build_ext --inplace

Usage

The kernels auto-activate when the model runs on an MPS device; no code changes are needed:

import torch
from mambapy.mamba import Mamba, MambaConfig

config = MambaConfig(d_model=384, n_layers=2)
model = Mamba(config).to('mps')  # Metal kernels used automatically

x = torch.randn(32, 16, 384, device='mps')
y = model(x)

Falls back gracefully to the pure-PyTorch implementation on CUDA/CPU.
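
The device-based dispatch can be sketched as follows; the function and flag names are hypothetical, the real selection lives in mambapy/mamba.py:

```python
def select_scan_backend(device_type, metal_ext_available):
    """Pick the scan implementation for a given torch device type.

    Hypothetical sketch of the dispatch described in the PR: fused Metal
    kernels only on MPS with the extension built, pure PyTorch otherwise.
    """
    if device_type == "mps" and metal_ext_available:
        return "metal"    # fused Metal kernels
    return "pytorch"      # pure-PyTorch scan on CUDA/CPU, or MPS without the extension
```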

Tested On

  • M1 Pro, M3 Max
  • PyTorch 2.1+
  • macOS Sonoma/Sequoia

Commit Notes

  • Use stream->commandEncoder() instead of COMMIT_AND_WAIT for zero-overhead integration with PyTorch's MPS backend.

v1.2.1:
  • Register metal_pscan ops with torch.library for torch.compile compatibility
  • Add torch_ops.py with MPS and Meta implementations for all ops
  • Update mamba.py to use torch.ops.metal_pscan.* when available
  • Backend selection happens at import time (zero overhead in the hot path)
  • Falls back to direct _C.* calls if torch_ops fails to load
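
The import-time selection with fallback can be sketched as below. Module names follow the PR description; the exact layout and the injected import function are assumptions for illustration:

```python
# Sketch of import-time backend selection: try the torch.library-registered
# ops module first, then the compiled extension directly, then fall back to
# pure PyTorch. `import_module` is injected so the logic is testable without
# the Metal extension installed; at import time it would be importlib's.
def resolve_backend(import_module):
    for name in ("metal_pscan.torch_ops",   # torch.ops.metal_pscan.* wrappers
                 "metal_pscan._C"):         # direct compiled-extension calls
        try:
            return name, import_module(name)
        except ImportError:
            continue
    return "pytorch", None  # pure-PyTorch fallback
```

Because this runs once at import time, the per-call hot path pays no dispatch cost: callers see a single resolved backend.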
