Hey! I've been using metal-flash-attention for a PyTorch training project on Apple Silicon and made some additions:
Fork: imperatormk/metal-flash-attention
Changes:
- ✅ Causal masking support (for autoregressive models)
- ✅ macOS 15 (Sequoia) compatibility fixes
- ✅ Tested with PyTorch MPS backend
PyTorch Bindings: imperatormk/mps-flash-attention
Drop-in Flash Attention for PyTorch on Apple Silicon:
- O(N) memory instead of O(N²)
- 2-5x faster than PyTorch SDPA
- Full backward pass for training
- Pre-compiled kernels for zero cold start
```python
from mps_flash_attn import flash_attention

out = flash_attention(q, k, v, is_causal=True)
```

If you're interested in merging any of these changes upstream, happy to open a PR!
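For anyone curious how the O(N)-memory claim works: flash attention never materializes the full N×N score matrix; it streams over key/value blocks with an online softmax, keeping only a running max and normalizer per query. Here is a minimal NumPy sketch of that idea with causal masking (illustration only, not the actual Metal kernel; names like `flash_causal_attention` are made up for the sketch):

```python
import numpy as np

def naive_causal_attention(q, k, v):
    # Reference implementation: materializes the full N x N score
    # matrix, so peak memory is O(N^2).
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    scores[np.triu(np.ones((n, n), dtype=bool), k=1)] = -np.inf  # causal mask
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return w @ v

def flash_causal_attention(q, k, v, block=16):
    # Tiled online-softmax attention: only one block of scores is live
    # at a time, so extra memory is O(block), not O(N^2).
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(v)
    for i in range(n):  # one query at a time for clarity; real kernels tile queries too
        m = -np.inf                 # running max for a stable softmax
        l = 0.0                     # running softmax normalizer
        acc = np.zeros(d)           # running weighted sum of values
        for j0 in range(0, i + 1, block):
            j1 = min(j0 + block, i + 1)  # causal: only keys up to position i
            s = (q[i] @ k[j0:j1].T) * scale
            m_new = max(m, s.max())
            p = np.exp(s - m_new)
            corr = np.exp(m - m_new)     # rescale previous partial sums
            l = l * corr + p.sum()
            acc = acc * corr + p @ v[j0:j1]
            m = m_new
        out[i] = acc / l
    return out

rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 64, 8))
assert np.allclose(flash_causal_attention(q, k, v),
                   naive_causal_attention(q, k, v))
```

The rescale-by-`exp(m - m_new)` step is what lets each block be folded into the running result without ever revisiting earlier scores; the backward pass recomputes scores per block the same way instead of storing them.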
Thanks for the original library - it's been super helpful for training on M1/M3 Macs.