@@ -65,15 +65,8 @@ def is_contiguous(x: Tensor):
6565
6666# kernels
6767
68- @triton .heuristics (
69- {
70- "EVEN_M" : lambda args : divisible_by (args ["seqlen_q" ], args ["BLOCK" ]),
71- "EVEN_N" : lambda args : divisible_by (args ["seqlen_k" ], args ["BLOCK" ]),
72- "EVEN_HEADDIM" : lambda args : args ["headdim" ] == args ["BLOCK_HEADDIM" ],
73- }
74- )
7568@triton .jit
76- def forward_kernel (
69+ def forward_kernel_causal_and_sparse (
7770 Q ,
7871 K ,
7972 V ,
@@ -415,6 +408,99 @@ def forward_kernel(
415408 mask = (offs_m [:, None , None ] < seqlen_q ) & (offs_d [None , None , :] < headdim )
416409 )
417410
# Heuristics-decorated Triton entry point that delegates all of its work to
# `forward_kernel_causal_and_sparse`.
#
# The `triton.heuristics` mapping below derives the three EVEN_* values from
# the other launch arguments, which each lambda looks up by name in `args`:
#   EVEN_M       -> divisible_by(seqlen_q, BLOCK)
#   EVEN_N       -> divisible_by(seqlen_k, BLOCK)
#   EVEN_HEADDIM -> headdim == BLOCK_HEADDIM
# NOTE(review): `divisible_by` is defined outside this view; presumably it is
# true when the first argument is a whole multiple of the second - confirm
# against its definition.
@triton.heuristics(
    {
        "EVEN_M": lambda args: divisible_by(args["seqlen_q"], args["BLOCK"]),
        "EVEN_N": lambda args: divisible_by(args["seqlen_k"], args["BLOCK"]),
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def forward_kernel(
    # NOTE(review): presumably the query / key / value inputs, the selected
    # kv-block indices with their mask, and the output / log-sum-exp buffers -
    # names only; confirm against the launch site.
    Q,
    K,
    V,
    kv_block_indices,
    kv_block_mask,
    Out,
    Lse,
    softmax_scale,
    # Stride arguments, grouped by the buffer each name refers to.
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_ob,
    stride_oh,
    stride_om,
    stride_kvbl_b,
    stride_kvbl_h,
    stride_kvbl_m,
    stride_lse_b,
    # Runtime (non-constexpr) scalar arguments.
    kv_heads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    # Compile-time constants. EVEN_M / EVEN_N / EVEN_HEADDIM are the keys
    # produced by the heuristics mapping above.
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK: tl.constexpr,
    QUERY_HEAD_GROUPS: tl.constexpr,
    QUERY_EXPAND_DIM: tl.constexpr,
    NUM_SEL_KV_BLOCKS: tl.constexpr,
    INCLUDE_BLOCK_CAUSAL: tl.constexpr
):
    # Pure pass-through: every parameter is forwarded positionally, in the
    # same order it is declared above, and nothing else is computed here.
    # The callee's full signature is outside this view, so its parameter
    # order must stay in lockstep with this list - verify both sides whenever
    # either one changes.
    forward_kernel_causal_and_sparse(
        Q,
        K,
        V,
        kv_block_indices,
        kv_block_mask,
        Out,
        Lse,
        softmax_scale,
        stride_qb,
        stride_qh,
        stride_qm,
        stride_kb,
        stride_kh,
        stride_kn,
        stride_vb,
        stride_vh,
        stride_vn,
        stride_ob,
        stride_oh,
        stride_om,
        stride_kvbl_b,
        stride_kvbl_h,
        stride_kvbl_m,
        stride_lse_b,
        kv_heads,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        headdim,
        CACHE_KEY_SEQLEN_Q,
        CACHE_KEY_SEQLEN_K,
        BLOCK_HEADDIM,
        EVEN_M,
        EVEN_N,
        EVEN_HEADDIM,
        BLOCK,
        QUERY_HEAD_GROUPS,
        QUERY_EXPAND_DIM,
        NUM_SEL_KV_BLOCKS,
        INCLUDE_BLOCK_CAUSAL
    )
503+
418504def native_sparse_attn_forward (
419505 q ,
420506 k ,
0 commit comments