File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed
tritonbench/operators/blackwell_attentions Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change 67
67
try:
68
68
import xformers # @manual=//fair/xformers:xformers
69
69
import xformers.ops.fmha as xformers_fmha  # @manual=//fair/xformers:xformers
70
+ from xformers.ops.fmha import MemoryEfficientAttentionCutlassBlackwellOp
70
71
71
72
from ..flash_attention .test_fmha_utils import permute_qkv
72
73
@@ -316,11 +317,11 @@ def cutlass_blackwell(
316
317
k: torch.Tensor,
317
318
v: torch.Tensor,
318
319
) -> Callable:
319
- need_gradient = not (self.mode == BenchmarkMode.FWD_NO_GRAD)
320
320
fhma_input = self.xformers_preprocess(q, k, v)
321
- xformers_cutlass_fhma = xformers.ops.fmha.cutlass_blackwell.FwOp
322
- return lambda: xformers_cutlass_fhma().apply(
323
- fhma_input, needs_gradient=need_gradient
321
+
322
+ return lambda: xformers.ops.fmha._memory_efficient_attention(
323
+ fhma_input,
324
+ op=MemoryEfficientAttentionCutlassBlackwellOp,
324
325
)
325
326
326
327
@register_benchmark(enabled=HAS_XFORMERS, fwd_only=True)
You can’t perform that action at this time.
0 commit comments