
Commit a0c6d82

Call _memory_efficient_attention for bwd of cutlass blackwell fmha as well
Differential Revision: D82490887
Pull Request resolved: meta-pytorch#423
1 parent: ae51ff5

File tree: 1 file changed (+5, -4 lines)


tritonbench/operators/blackwell_attentions/operator.py

Lines changed: 5 additions & 4 deletions
@@ -67,6 +67,7 @@
 try:
     import xformers  # @manual=//fair/xformers:xformers
     import xformers.ops.fmha as xformers_fmha  # @manual=//fair/xformers:xformers
+    from xformers.ops.fmha import MemoryEfficientAttentionCutlassBlackwellOp

     from ..flash_attention.test_fmha_utils import permute_qkv

@@ -316,11 +317,11 @@ def cutlass_blackwell(
         k: torch.Tensor,
         v: torch.Tensor,
     ) -> Callable:
-        need_gradient = not (self.mode == BenchmarkMode.FWD_NO_GRAD)
         fhma_input = self.xformers_preprocess(q, k, v)
-        xformers_cutlass_fhma = xformers.ops.fmha.cutlass_blackwell.FwOp
-        return lambda: xformers_cutlass_fhma().apply(
-            fhma_input, needs_gradient=need_gradient
+
+        return lambda: xformers.ops.fmha._memory_efficient_attention(
+            fhma_input,
+            op=MemoryEfficientAttentionCutlassBlackwellOp,
         )

     @register_benchmark(enabled=HAS_XFORMERS, fwd_only=True)
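
Note on the change: the old callable applied cutlass_blackwell.FwOp directly, which only runs the forward kernel, so the BWD benchmark mode could not reach the Blackwell backward kernel. Routing through xformers.ops.fmha._memory_efficient_attention with MemoryEfficientAttentionCutlassBlackwellOp (presumably a forward/backward op pair, following the xformers convention for *Op tuples) lets autograd dispatch the backward op as well. A minimal sketch of how such an op pair is exercised end to end, using the public memory_efficient_attention wrapper instead of the operator's preprocessed Inputs object; the shapes are illustrative only, and it assumes the installed xformers build exports MemoryEfficientAttentionCutlassBlackwellOp and a Blackwell GPU is available:

import torch
import xformers.ops.fmha as fmha
from xformers.ops.fmha import MemoryEfficientAttentionCutlassBlackwellOp

# Illustrative shapes only: (batch, seqlen, heads, head_dim).
q = torch.randn(4, 1024, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Forward through the Blackwell op pair; autograd tracks the op so a
# backward kernel is registered for the output.
out = fmha.memory_efficient_attention(
    q, k, v, op=MemoryEfficientAttentionCutlassBlackwellOp
)

# Backward now dispatches the cutlass Blackwell backward kernel, which is
# what this commit makes reachable from the benchmarked callable.
out.sum().backward()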
