
Commit ddc8cd5

Adds varlen Flash-DM attention (CUDA) + autograd
Introduces CUDA custom ops and fake/meta registrations for variable-length attention forward/backward using cumulative sequence lengths, enabling efficient packed ragged batches. Adds an autograd wrapper and public API exposing varlen attention with support for MQA/GQA, optional mask/bias, causal mode, softcap, deterministic backward, and optional attention probs/LSE for testing. Pads head dim and key seqlen to multiples of 8 for 16‑bit–friendly allocations, rounds workspace shapes, and sanitizes outputs; also supports paged KV via a block table. Improves performance and memory by avoiding per-sequence padding and aligning allocations to hardware-friendly sizes.
1 parent 060471f

1 file changed (+407, -0 lines)

