
Commit 7e4f294

Aya-ZIbra authored and facebook-github-bot committed
Add cutlass decode kernel to TritonBench
Summary:
X-link: pytorch/FBGEMM#4853
X-link: facebookresearch/FBGEMM#1875

Add the CUTLASS Blackwell FMHA decode kernel implementation to the TritonBench benchmarking suite.

Reviewed By: sryap

Differential Revision: D80041532
1 parent: d152a59 · commit: 7e4f294

File tree: 1 file changed (+60, −0 lines)

tritonbench/operators/decoding_attention/operator.py

Lines changed: 60 additions & 0 deletions
@@ -77,6 +77,18 @@
 except (ImportError, IOError, AttributeError):
     HAS_AITER = False
 
+# [Optional] cutlass_blackwell_fmha backend
+HAS_CUTLASS_BLACKWELL = True
+try:
+    from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
+        cutlass_blackwell_fmha_interface as blackwell,
+    )
+    # Disable FA3 for Blackwell as it doesn't work properly
+    HAS_FLASH_V3 = False
+    # Note: We keep FA2 and triton enabled alongside Blackwell for comparison
+except (ImportError, IOError, AttributeError):
+    HAS_CUTLASS_BLACKWELL = False
+
 
 # [Optional] flash_fwd cute-DSL backend
 HAS_FLASH_CUTE = True
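
The new guard follows the same optional-backend pattern used for the other decoding-attention backends in operator.py: the FBGEMM import is attempted once at module load, HAS_CUTLASS_BLACKWELL gates benchmark registration, and FA3 is switched off whenever the Blackwell interface is present. As a quick sanity check (a sketch, not part of this commit), the flags can be inspected after importing the operator module:

    # Sketch only: report which optional decoding-attention backends were
    # detected at import time on the current machine.
    from tritonbench.operators.decoding_attention import operator as decoding_attention_op

    for flag in ("HAS_CUTLASS_BLACKWELL", "HAS_FLASH_V3", "HAS_AITER", "HAS_FLASH_CUTE"):
        print(flag, getattr(decoding_attention_op, flag, "n/a"))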
@@ -591,6 +603,54 @@ def flash_cute_dsl(
             q, k_cache, v_cache, causal=CAUSAL, pack_gqa=(q_heads != kv_heads)
         )
 
+    @register_benchmark(enabled=HAS_CUTLASS_BLACKWELL)
+    def cutlass_blackwell_fmha_decode_fp8qkv(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        seq_len_q = q.shape[1]
+
+        # Cutlass Blackwell FMHA currently supports decode case (seq_len_q == 1)
+        if seq_len_q != 1:
+            # Skip non-decode cases for now
+            raise NotImplementedError("Cutlass Blackwell FMHA only supports decode case")
+            # return lambda: q.new_zeros(q.shape)
+
+        # Convert to fp8 format as required by the decode path
+        _q = q.to(torch.float8_e4m3fn)
+        _k_cache = k_cache.to(torch.float8_e4m3fn)
+        _v_cache = v_cache.to(torch.float8_e4m3fn)
+
+        # Create seqlen_kv tensor for generation phase
+        seqlen_kv = cache_seqlens.to(dtype=torch.int32, device=q.device)
+
+        return lambda: blackwell.cutlass_blackwell_fmha_func(
+            _q, _k_cache, _v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv
+        )
+    @register_benchmark(enabled=HAS_CUTLASS_BLACKWELL)
+    def cutlass_blackwell_fmha_decode(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        seq_len_q = q.shape[1]
+
+        # Cutlass Blackwell FMHA currently supports decode case (seq_len_q == 1)
+        if seq_len_q != 1:
+            # Skip non-decode cases for now
+            raise NotImplementedError("Cutlass Blackwell FMHA only supports decode case")
+
+        # Create seqlen_kv tensor for generation phase
+        seqlen_kv = cache_seqlens.to(dtype=torch.int32, device=q.device)
+
+        return lambda: blackwell.cutlass_blackwell_fmha_func(
+            q, k_cache, v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv
+        )
     @register_benchmark(enabled=HAS_AITER)
     def aiter_paged_fp8kv(
         self,
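
Both benchmarks return a lambda that wraps blackwell.cutlass_blackwell_fmha_func, so only the kernel call itself sits in the timed region; in the fp8 variant, the casts of q, k_cache, and v_cache to torch.float8_e4m3fn happen before the lambda is built and are therefore excluded from the measurement. A minimal standalone smoke test of the bf16 decode path could look like the sketch below; the tensor layout (batch, seq, heads, head_dim), the concrete shapes, and causal=True are assumptions not taken from this diff, and running it requires an FBGEMM build that ships the Blackwell FMHA kernels plus a matching GPU.

    # Sketch only: shapes, layout, and the causal flag are assumptions.
    import torch
    from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
        cutlass_blackwell_fmha_interface as blackwell,
    )

    B, H_Q, H_KV, D, MAX_T = 2, 8, 1, 128, 4096
    q = torch.randn(B, 1, H_Q, D, device="cuda", dtype=torch.bfloat16)  # seq_len_q == 1 (decode)
    k_cache = torch.randn(B, MAX_T, H_KV, D, device="cuda", dtype=torch.bfloat16)
    v_cache = torch.randn_like(k_cache)
    seqlen_kv = torch.randint(1, MAX_T, (B,), device="cuda", dtype=torch.int32)

    out = blackwell.cutlass_blackwell_fmha_func(
        q, k_cache, v_cache, causal=True, seqlen_kv=seqlen_kv
    )
    print(out.shape)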
