77 | 77 | except (ImportError, IOError, AttributeError):
78 | 78 |     HAS_AITER = False
79 | 79 |
| 80 | +# [Optional] cutlass_blackwell_fmha backend
| 81 | +HAS_CUTLASS_BLACKWELL = True
| 82 | +try:
| 83 | +    from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
| 84 | +        cutlass_blackwell_fmha_interface as blackwell,
| 85 | +    )
| 86 | +    # Disable FA3 when the Blackwell backend is available, since FA3 does not run correctly on Blackwell
| 87 | +    HAS_FLASH_V3 = False
| 88 | +    # Note: FA2 and the Triton backend remain enabled alongside Blackwell for comparison
| 89 | +except (ImportError, IOError, AttributeError):
| 90 | +    HAS_CUTLASS_BLACKWELL = False
| 91 | +
80 | 92 |
81 | 93 | # [Optional] flash_fwd cute-DSL backend
82 | 94 | HAS_FLASH_CUTE = True
@@ -591,6 +603,54 @@ def flash_cute_dsl(
591 | 603 |             q, k_cache, v_cache, causal=CAUSAL, pack_gqa=(q_heads != kv_heads)
592 | 604 |         )
593 | 605 |
| 606 | +    @register_benchmark(enabled=HAS_CUTLASS_BLACKWELL)
| 607 | +    def cutlass_blackwell_fmha_decode_fp8qkv(
| 608 | +        self,
| 609 | +        q: torch.Tensor,
| 610 | +        k_cache: torch.Tensor,
| 611 | +        v_cache: torch.Tensor,
| 612 | +        cache_seqlens: torch.Tensor,
| 613 | +    ) -> Callable:
| 614 | +        seq_len_q = q.shape[1]
| 615 | +
| 616 | +        # Cutlass Blackwell FMHA currently supports only the decode case (seq_len_q == 1)
| 617 | +        if seq_len_q != 1:
| 618 | +            # Skip non-decode cases for now
| 619 | +            raise NotImplementedError("Cutlass Blackwell FMHA only supports the decode case")
| 620 | +            # (alternative: return lambda: q.new_zeros(q.shape) to skip instead of raising)
| 621 | +
| 622 | +        # Convert Q, K, and V to FP8 (e4m3) for the FP8 decode variant
| 623 | +        _q = q.to(torch.float8_e4m3fn)
| 624 | +        _k_cache = k_cache.to(torch.float8_e4m3fn)
| 625 | +        _v_cache = v_cache.to(torch.float8_e4m3fn)
| 626 | +
| 627 | +        # Per-sequence KV lengths for the generation (decode) phase
| 628 | +        seqlen_kv = cache_seqlens.to(dtype=torch.int32, device=q.device)
| 629 | +
| 630 | +        return lambda: blackwell.cutlass_blackwell_fmha_func(
| 631 | +            _q, _k_cache, _v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv
| 632 | +        )
| 633 | +    @register_benchmark(enabled=HAS_CUTLASS_BLACKWELL)
| 634 | +    def cutlass_blackwell_fmha_decode(
| 635 | +        self,
| 636 | +        q: torch.Tensor,
| 637 | +        k_cache: torch.Tensor,
| 638 | +        v_cache: torch.Tensor,
| 639 | +        cache_seqlens: torch.Tensor,
| 640 | +    ) -> Callable:
| 641 | +        seq_len_q = q.shape[1]
| 642 | +
| 643 | +        # Cutlass Blackwell FMHA currently supports only the decode case (seq_len_q == 1)
| 644 | +        if seq_len_q != 1:
| 645 | +            # Skip non-decode cases for now
| 646 | +            raise NotImplementedError("Cutlass Blackwell FMHA only supports the decode case")
| 647 | +
| 648 | +        # Per-sequence KV lengths for the generation (decode) phase
| 649 | +        seqlen_kv = cache_seqlens.to(dtype=torch.int32, device=q.device)
| 650 | +
| 651 | +        return lambda: blackwell.cutlass_blackwell_fmha_func(
| 652 | +            q, k_cache, v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv
| 653 | +        )
594 | 654 |     @register_benchmark(enabled=HAS_AITER)
595 | 655 |     def aiter_paged_fp8kv(
596 | 656 |         self,
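For reference, a minimal standalone sketch of the call pattern the new decode benchmarks exercise. The tensor layout (batch, seq_len, heads, head_dim), the head counts, the cache length, and `causal=True` are illustrative assumptions, not taken from this commit; only the `blackwell.cutlass_blackwell_fmha_func(..., causal=..., seqlen_kv=...)` call mirrors the code above, and it requires a Blackwell GPU with the fbgemm_gpu experimental ops installed.

```python
import torch

from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
    cutlass_blackwell_fmha_interface as blackwell,
)

# Assumed shapes for illustration: batch=4, 8 query heads, 1 KV head,
# head_dim=128, KV cache length 4096; layout (batch, seq_len, heads, head_dim).
B, H_Q, H_KV, D, S_KV = 4, 8, 1, 128, 4096

# Decode case: exactly one query token per sequence (seq_len_q == 1).
q = torch.randn(B, 1, H_Q, D, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(B, S_KV, H_KV, D, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn(B, S_KV, H_KV, D, device="cuda", dtype=torch.bfloat16)

# Per-sequence KV lengths for the generation phase, as in the benchmark methods above.
seqlen_kv = torch.full((B,), S_KV, dtype=torch.int32, device="cuda")

out = blackwell.cutlass_blackwell_fmha_func(
    q, k_cache, v_cache, causal=True, seqlen_kv=seqlen_kv
)
print(out.shape)  # the output is expected to match q's shape
```

For the FP8 variant, the benchmark first casts q/k/v to `torch.float8_e4m3fn` and passes the converted tensors to the same call.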