Commit 0018b1b

PyTorch FA test fix (#370)
* Corrected bugs
* Updated documentation
1 parent: 610173a

2 files changed: +4 additions, -1 deletion


tests/pytorch/fused_attn/test_fused_attn.py

Lines changed: 2 additions & 1 deletion
@@ -220,8 +220,8 @@ def test():
     )
     (
         use_flash_attention,
-        use_fused_attention,
         flash_attention_backend,
+        use_fused_attention,
         fused_attention_backend,
         use_unfused_attention,
         available_backends,

@@ -369,6 +369,7 @@ def test_dot_product_attention(
         and config.attn_mask_type in ["causal", "padding_causal"]
     )
     and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
+    and not is_mla
 ):
     flash_attn_supported = True
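Context for the first hunk: the test unpacks the tuple produced by get_attention_backend positionally, so the variable names must follow the order in which the values are returned (the order documented in utils.py below). Before the fix, use_fused_attention and flash_attention_backend were swapped, binding a version object to a boolean flag and vice versa. The sketch below illustrates that failure mode; the _backend_info helper and its values are hypothetical stand-ins, not code from the repository.

# A minimal sketch of the unpack-order bug, not code from the repository.
# _backend_info is a hypothetical stand-in for get_attention_backend's
# return tuple, truncated to its first three elements.
from packaging.version import Version as PkgVersion

def _backend_info():
    # returns (use_flash_attention, flash_attention_backend, use_fused_attention)
    return True, PkgVersion("2.4.2"), False

# Wrong order (as before the fix): the PkgVersion lands in use_fused_attention,
# so a later truthiness check would treat FusedAttention as selected.
use_flash_attention, use_fused_attention, flash_attention_backend = _backend_info()
assert not isinstance(use_fused_attention, bool)   # actually holds a PkgVersion
assert bool(use_fused_attention)                   # truthy even though the flag was False

# Correct order (as after the fix), matching the producer's return order.
use_flash_attention, flash_attention_backend, use_fused_attention = _backend_info()
assert isinstance(use_fused_attention, bool)
assert isinstance(flash_attention_backend, PkgVersion)

The second hunk simply excludes MLA configurations (is_mla) from the cases where flash_attn_supported is set, in addition to the existing mask-type and window-size conditions.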

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 2 additions & 0 deletions
@@ -278,6 +278,8 @@ def get_attention_backend(
     ----------
     use_flash_attention: bool
         Whether the `FlashAttention` backend has been selected.
+    flash_attention_backend: PkgVersion
+        If `use_flash_attention = True`, the version of the selected `FlashAttention` backend.
     use_fused_attention: bool
         Whether the `FusedAttention` backend has been selected.
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend
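The added docstring lines record that flash_attention_backend is returned as a PkgVersion rather than a string, so callers can gate features with ordinary version comparisons. A rough sketch of that pattern, assuming packaging's Version as the PkgVersion alias; the 2.3 threshold mirrors the FlashAttentionUtils.v2_3_plus check in the test hunk above, and the concrete values are illustrative only.

# Sketch only: how a caller might use the documented flash_attention_backend value.
from packaging.version import Version as PkgVersion

use_flash_attention = True
flash_attention_backend = PkgVersion("2.4.2")  # only meaningful when use_flash_attention is True

# Sliding-window (window_size[0] != -1) support requires FlashAttention >= 2.3,
# per the window_size / v2_3_plus condition in test_dot_product_attention.
supports_sliding_window = use_flash_attention and flash_attention_backend >= PkgVersion("2.3")
print(supports_sliding_window)  # True for 2.4.2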
