Skip to content

Commit 26a8892

Browse files
committed
indirection needed for the forward Triton kernel, in preparation for dispatching sliding window attention + fine attention in parallel next week
1 parent 1f11855 commit 26a8892

File tree

1 file changed

+94
-8
lines changed

1 file changed

+94
-8
lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 94 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,8 @@ def is_contiguous(x: Tensor):
6565

6666
# kernels
6767

68-
@triton.heuristics(
69-
{
70-
"EVEN_M": lambda args: divisible_by(args["seqlen_q"], args["BLOCK"]),
71-
"EVEN_N": lambda args: divisible_by(args["seqlen_k"], args["BLOCK"]),
72-
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
73-
}
74-
)
7568
@triton.jit
76-
def forward_kernel(
69+
def forward_kernel_causal_and_sparse(
7770
Q,
7871
K,
7972
V,
@@ -415,6 +408,99 @@ def forward_kernel(
415408
mask = (offs_m[:, None, None] < seqlen_q) & (offs_d[None, None, :] < headdim)
416409
)
417410

411+
# Dispatch wrapper around `forward_kernel_causal_and_sparse`.
#
# The heuristics are attached at this outer level (they were moved off the
# inner kernel in this commit), so the inner @triton.jit function remains
# directly callable from other jitted entry points. Per the commit message,
# this indirection prepares for dispatching sliding-window attention and
# fine attention in parallel from a single entry kernel — TODO confirm once
# that follow-up lands.
@triton.heuristics(
    {
        # True when seqlen_q is a multiple of BLOCK — presumably lets the
        # inner kernel skip boundary masking along the query (M) dimension.
        "EVEN_M": lambda args: divisible_by(args["seqlen_q"], args["BLOCK"]),
        # True when seqlen_k is a multiple of BLOCK (key/value N dimension).
        "EVEN_N": lambda args: divisible_by(args["seqlen_k"], args["BLOCK"]),
        # True when the actual head dimension exactly fills the padded
        # BLOCK_HEADDIM tile.
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def forward_kernel(
    Q,                    # query tensor
    K,                    # key tensor
    V,                    # value tensor
    kv_block_indices,     # indices of selected KV blocks (sparse selection)
    kv_block_mask,        # validity mask paired with kv_block_indices
    Out,                  # output tensor
    Lse,                  # log-sum-exp accumulator — NOTE(review): presumably
                          # saved for the backward pass; confirm against caller
    softmax_scale,
    # Q strides: batch, head, sequence
    stride_qb,
    stride_qh,
    stride_qm,
    # K strides: batch, head, sequence
    stride_kb,
    stride_kh,
    stride_kn,
    # V strides: batch, head, sequence
    stride_vb,
    stride_vh,
    stride_vn,
    # Out strides: batch, head, sequence
    stride_ob,
    stride_oh,
    stride_om,
    # kv_block_indices strides: batch, head, sequence-block
    stride_kvbl_b,
    stride_kvbl_h,
    stride_kvbl_m,
    stride_lse_b,         # Lse batch stride
    kv_heads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,       # populated by the @triton.heuristics above
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK: tl.constexpr,
    QUERY_HEAD_GROUPS: tl.constexpr,
    QUERY_EXPAND_DIM: tl.constexpr,
    NUM_SEL_KV_BLOCKS: tl.constexpr,
    INCLUDE_BLOCK_CAUSAL: tl.constexpr
):
    # Pure pass-through: forward every argument, unchanged and in the same
    # order, to the inner jitted kernel. Keep the two parameter lists in sync.
    forward_kernel_causal_and_sparse(
        Q,
        K,
        V,
        kv_block_indices,
        kv_block_mask,
        Out,
        Lse,
        softmax_scale,
        stride_qb,
        stride_qh,
        stride_qm,
        stride_kb,
        stride_kh,
        stride_kn,
        stride_vb,
        stride_vh,
        stride_vn,
        stride_ob,
        stride_oh,
        stride_om,
        stride_kvbl_b,
        stride_kvbl_h,
        stride_kvbl_m,
        stride_lse_b,
        kv_heads,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        headdim,
        CACHE_KEY_SEQLEN_Q,
        CACHE_KEY_SEQLEN_K,
        BLOCK_HEADDIM,
        EVEN_M,
        EVEN_N,
        EVEN_HEADDIM,
        BLOCK,
        QUERY_HEAD_GROUPS,
        QUERY_EXPAND_DIM,
        NUM_SEL_KV_BLOCKS,
        INCLUDE_BLOCK_CAUSAL
    )
503+
418504
def native_sparse_attn_forward(
419505
q,
420506
k,

0 commit comments

Comments
 (0)