4 changes: 3 additions & 1 deletion .gitignore
@@ -2,4 +2,6 @@
**/*.npz
old_tests/
*.egg-info/
build/
build/metal_pscan/*.so
metal_pscan/__pycache__/
build/
175 changes: 175 additions & 0 deletions METAL_MAMBA_README.md
@@ -0,0 +1,175 @@
# Metal Mamba

Native Metal implementation of Mamba's parallel scan for Apple Silicon (M1/M2/M3/M4).

## Performance

```
============================================================
Metal PScan vs PyTorch PScan on MPS
============================================================

Config: B=2, D=128, N=16

Seq Len | Metal (ms) | PyTorch (ms) | Speedup
----------------------------------------------------
64 | 0.654 | 0.869 | 1.33x
128 | 0.694 | 1.344 | 1.94x
256 | 0.772 | 2.132 | 2.76x
512 | 0.548 | 3.407 | 6.21x
1024 | 1.024 | 6.359 | 6.21x
```

The Metal kernel parallelizes across all B·D·N slices in a single dispatch, while the PyTorch implementation pays Python-level and kernel-launch overhead at every scan step.

## How it Works

The parallel scan computes:
```
H[t] = A[t] * H[t-1] + X[t] with H[0] = X[0]
```

This is the core operation in Mamba's selective scan mechanism.
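As a reference, the recurrence can be written as a short sequential loop in plain PyTorch (illustrative only; the Metal kernels parallelize this):

```python
import torch

def pscan_ref(A, X):
    """Sequential reference for H[t] = A[t] * H[t-1] + X[t], with H[0] = X[0].

    A, X: (B, L, D, N) tensors; returns H of the same shape.
    """
    H = torch.empty_like(X)
    H[:, 0] = X[:, 0]
    for t in range(1, X.shape[1]):
        H[:, t] = A[:, t] * H[:, t - 1] + X[:, t]
    return H
```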

### Algorithm: Blelloch Scan

We implement Blelloch's work-efficient parallel prefix scan:

1. **Up-sweep (Reduce)**: Combine pairs of elements up the tree
2. **Down-sweep**: Propagate partial sums back down

For a sequence of length L, this requires only O(log L) parallel steps instead of O(L) sequential steps.

### Metal Optimizations

- **SIMD shuffle operations**: Use `simd_shuffle_up` for efficient intra-warp communication
- **Threadgroup memory**: Cache intermediate results for cross-SIMD communication
- **Coalesced memory access**: Process (B, D, N) slices in parallel across sequence length
- **Function constants**: Compile-time specialization for each tensor shape

## Project Structure

```
mamba-metal/
├── mambapy/ # Original mamba.py (pure PyTorch)
│ ├── mamba.py # Mamba model
│ └── pscan.py # PyTorch parallel scan
├── metal/ # Metal implementation
│ ├── Package.swift
│ └── Sources/
│ ├── MetalMamba/
│ │ ├── pscan.metal # Metal shader for parallel scan
│ │ └── PScanKernel.swift
│ ├── MetalMambaBridge/
│ │ └── MambaBridge.swift # C-callable interface
│ └── TestPScan/
│ └── main.swift # Swift test
└── metal_pscan/ # Python wrapper
└── __init__.py # PyTorch integration
```

## Installation

### Prerequisites

- macOS 14 (Sonoma) or later
- Xcode Command Line Tools
- Python 3.10+ with PyTorch 2.0+

### Build

```bash
# Clone the repo
git clone https://github.com/alxndrTL/mamba.py.git mamba-metal
cd mamba-metal

# Build the Metal library
cd metal
swift build -c release
cd ..

# Set environment variable (add to ~/.zshrc)
export METAL_MAMBA_BRIDGE_PATH=/path/to/mamba-metal/metal/.build/release/libMetalMambaBridge.dylib
```

## Usage

### Python (Drop-in replacement)

```python
from metal_pscan import metal_pscan, is_available

if is_available():
    # Same interface as mambapy.pscan.pscan
    H = metal_pscan(A, X)
```

### With Mamba model

```python
import sys
import torch

sys.path.insert(0, '/path/to/mamba-metal')

from mambapy.mamba import Mamba, MambaConfig

# Patch pscan to use Metal
from metal_pscan import metal_pscan, is_available
if is_available():
    import mambapy.pscan
    mambapy.pscan.pscan = metal_pscan

# Create and use Mamba model
config = MambaConfig(d_model=256, n_layers=4)
model = Mamba(config).to('mps')

x = torch.randn(2, 1024, 256, device='mps')
y = model(x)
```

### Swift (Direct)

```swift
import MetalMamba

let kernel = try PScanKernel()
let config = PScanConfig(batchSize: 2, seqLen: 1024, dInner: 256, dState: 16)
try kernel.compile(config: config)

// Create buffers and run
try kernel.forward(A: A_buf, X: X_buf, H: H_buf, useSIMD: true)
```

## Shader Details

The Metal shader (`pscan.metal`) implements:

1. **`pscan_forward`**: Basic parallel scan using threadgroup memory
2. **`pscan_forward_simd`**: Optimized version using SIMD shuffle operations
3. **`pscan_backward`**: Gradient computation (reverse scan)

Key optimization: Each threadgroup processes one `(batch, d_inner, d_state)` element across the entire sequence, enabling maximum parallelism.
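The gradient that `pscan_backward` computes can be sketched as a sequential reverse scan (an illustrative PyTorch reference, not the shader): gradients flow right to left through the same first-order recurrence.

```python
import torch

def pscan_backward_ref(A, H, grad_H):
    # Reverse scan: dH[t] = grad_H[t] + A[t+1] * dH[t+1], right to left.
    L = A.shape[1]
    dH = grad_H.clone()
    for t in range(L - 2, -1, -1):
        dH[:, t] = dH[:, t] + A[:, t + 1] * dH[:, t + 1]
    grad_X = dH                              # dH[t]/dX[t] = 1
    grad_A = torch.zeros_like(A)             # A[0] is unused since H[0] = X[0]
    grad_A[:, 1:] = dH[:, 1:] * H[:, :-1]    # dH[t]/dA[t] = H[t-1]
    return grad_A, grad_X
```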

## Current Status

**Working:**
- Forward pass (verified against CPU reference)
- SIMD-optimized kernel
- Python wrapper with PyTorch integration

**TODO:**
- Metal backward pass (currently falls back to PyTorch)
- Direct MPS tensor integration (avoid CPU round-trip)
- Half precision (fp16) support
- Benchmark with full Mamba model

## Credits

- [mamba.py](https://github.com/alxndrTL/mamba.py) by alxndrTL - Pure PyTorch Mamba implementation
- [Mamba paper](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao
- [Blelloch scan](https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf) algorithm

## License

MIT
55 changes: 55 additions & 0 deletions README.md
@@ -3,6 +3,7 @@ A straightforward implementation of [Mamba](https://arxiv.org/abs/2312.00752) in
It combines ease of reading with good training performance. A few other functionalities are implemented, like [Jamba](https://www.ai21.com/blog/announcing-jamba), [Vision Mamba](https://arxiv.org/abs/2401.09417) as well as [muP](https://arxiv.org/abs/2203.03466).

## Updates
- <b>28/01/2025</b> : Added Metal (Apple Silicon MPS) fused kernels for ~18x speedup on Mac. See [Metal Acceleration](#metal-acceleration-apple-silicon) section below.
- <b>03/08/2024</b> : Added a muP implementation for Mamba and Mamba2. This allows to sweep for optimal hyperparameters on a small model and directly transfer them to a large model. See [this PR](https://github.com/alxndrTL/mamba.py/pull/50)
- <b>23/07/2024</b> : `mamba.py` is now part of the transformers 🤗 library. See [this PR](https://github.com/huggingface/transformers/pull/30139).
- <b>27/06/2024</b> : Deployed a package version of `mamba.py` on PyPI, which you can install with `pip install mambapy`.
@@ -39,6 +40,7 @@ This repo contains a simple and readable code implementing the [Mamba](https://a
- `vim.py` : an implementation of [Vision Mamba](https://arxiv.org/abs/2401.09417).
- `📁 onnx` : export a trained Mamba model in ONNX for inference.
- `📁 mlx` : basically the same code as above, but in MLX.
- `📁 metal_pscan` : Metal (Apple Silicon) fused kernels for MPS acceleration.
- `📁 docs` : a folder containing annotated explanations about the code, focusing on the parallel scan for now.
- `📁 examples` : two examples of how to use the Mamba model in PyTorch as well as a training file.

@@ -210,6 +212,59 @@ But memory requirement should also be considered: the official Mamba implementation
Hence, this repo implements one of the three techniques mentioned in the Mamba paper that form the so-called "hardware-aware selective scan": the parallel scan.
We saw how kernel fusion impacts speed, while recomputation reduces memory requirements.

___
## Metal Acceleration (Apple Silicon)

This fork adds custom Metal compute shaders for fast Mamba on Apple Silicon (M1/M2/M3/M4). The kernels are auto-detected when running on an MPS device.

### Installation

```bash
# Clone and build
git clone https://github.com/your-repo/mamba-metal
cd mamba-metal
pip install -e .

# Build Metal extension
python setup_metal.py build_ext --inplace
```

### Usage

```python
import torch
from mambapy.mamba import Mamba, MambaConfig, _USE_METAL_SSM

print(f"Metal acceleration: {_USE_METAL_SSM}") # True on Apple Silicon

config = MambaConfig(d_model=384, n_layers=1)
model = Mamba(config).to('mps')

x = torch.randn(32, 16, 384, device='mps')
y = model(x) # Uses Metal kernels automatically
```

### Performance

On M3 Max with batch=1024, seq_len=2, d_model=384:

| Implementation | Time | vs Attention |
|----------------|------|--------------|
| PyTorch MPS (baseline) | ~120ms | 57x slower |
| Metal fused kernels | ~7ms | 3.2x slower |

**~18x speedup** over the naive PyTorch MPS implementation.

### Fused Kernels

Three custom Metal kernels are provided:

1. **`conv1d_silu_fused`** - Fuses depthwise conv1d + SiLU activation (12x faster than PyTorch)
2. **`ssm_fused`** - Fuses SSM state preparation + parallel scan
3. **`ssm_output_fused`** - Super-fused SSM + output matmul (8x faster, avoids large intermediate tensor)

The kernels use Metal function constants for shape specialization and integrate with PyTorch's MPS stream.
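For reference, the unfused path that the first kernel replaces can be sketched in plain PyTorch (an illustrative sketch; the function name and exact shapes are assumptions, not this fork's API):

```python
import torch
import torch.nn.functional as F

def conv1d_silu_ref(x, weight, bias, d_conv):
    # x: (B, d_inner, L); weight: (d_inner, 1, d_conv) depthwise filter.
    # Causal depthwise conv1d followed by SiLU -- the ops that
    # `conv1d_silu_fused` performs in a single Metal dispatch.
    y = F.conv1d(x, weight, bias=bias, padding=d_conv - 1, groups=x.shape[1])
    y = y[..., : x.shape[-1]]  # trim right padding to keep causality
    return F.silu(y)
```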

___
## Sources and where to learn more
- the [Mamba paper](https://arxiv.org/abs/2312.00752) : describes the Mamba architecture as implemented in this repo, which allows to model sequences in linear time.
52 changes: 52 additions & 0 deletions bench_scaling.py
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
"""Benchmark Metal pscan scaling."""
import torch
import time
import sys
sys.path.insert(0, '/Users/zimski/projects/oss/mamba-metal')

import metal_pscan._C as _C
from mambapy.pscan import pscan as pytorch_pscan

print("=" * 60)
print("Metal PScan vs PyTorch PScan on MPS")
print("=" * 60)

B, D, N = 2, 128, 16
print(f"\nConfig: B={B}, D={D}, N={N}")
print(f"\n{'Seq Len':>8} | {'Metal (ms)':>12} | {'PyTorch (ms)':>13} | {'Speedup':>8}")
print("-" * 52)

for L in [64, 128, 256, 512, 1024]:
    A = torch.rand(B, L, D, N, device='mps', dtype=torch.float32) * 0.4 + 0.5
    X = torch.randn(B, L, D, N, device='mps', dtype=torch.float32)
    torch.mps.synchronize()

    # Warmup
    for _ in range(10):
        _ = _C.forward(A, X)
        _ = pytorch_pscan(A, X)
    torch.mps.synchronize()

    iters = 100

    # Metal
    torch.mps.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        _ = _C.forward(A, X)
    torch.mps.synchronize()
    metal_ms = (time.perf_counter() - t0) / iters * 1000

    # PyTorch
    torch.mps.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        _ = pytorch_pscan(A, X)
    torch.mps.synchronize()
    pytorch_ms = (time.perf_counter() - t0) / iters * 1000

    speedup = pytorch_ms / metal_ms
    print(f"{L:>8} | {metal_ms:>12.3f} | {pytorch_ms:>13.3f} | {speedup:>7.2f}x")

print()