Maintained fork with PyTorch bindings and macOS 15 support #32

@imperatormk

Description

Hey! I've been using metal-flash-attention for a PyTorch training project on Apple Silicon and made some additions:

Fork: imperatormk/metal-flash-attention

Changes:

  • ✅ Causal masking support (for autoregressive models)
  • ✅ macOS 15 (Sequoia) compatibility fixes
  • ✅ Tested with PyTorch MPS backend

PyTorch Bindings: imperatormk/mps-flash-attention

Drop-in Flash Attention for PyTorch on Apple Silicon:

  • O(N) memory instead of O(N²)
  • 2-5x faster than PyTorch SDPA
  • Full backward pass for training
  • Pre-compiled kernels for zero cold start

Usage:

from mps_flash_attn import flash_attention

out = flash_attention(q, k, v, is_causal=True)
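To make the O(N) vs O(N²) memory claim concrete: standard attention materializes the full N×N score matrix QKᵀ, which flash attention avoids by computing attention in tiles. A quick back-of-the-envelope sketch (the helper function below is illustrative, not part of the bindings):

```python
def attn_scores_bytes(batch, heads, seq_len, bytes_per_el=2):
    """Memory for the full N x N attention score matrix that standard
    attention materializes (assuming fp16, 2 bytes per element)."""
    return batch * heads * seq_len * seq_len * bytes_per_el

# At seq_len 8192 with batch 1 and 32 heads, the score matrix alone
# is 32 * 8192^2 * 2 bytes = 4 GiB -- before activations or gradients.
gib = attn_scores_bytes(1, 32, 8192) / 2**30
print(f"{gib:.0f} GiB")  # prints "4 GiB"
```

Flash attention never allocates this matrix, so memory scales with sequence length rather than its square, which is what makes long-context training feasible in the limited unified memory of an M-series Mac.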

If you're interested in merging any of these changes upstream, happy to open a PR!

Thanks for the original library - it's been super helpful for training on M1/M3 Macs.
