diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index 43f1d4488..16791266e 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -219,8 +219,8 @@ def test():
         )
         (
             use_flash_attention,
-            use_fused_attention,
             flash_attention_backend,
+            use_fused_attention,
             fused_attention_backend,
             use_unfused_attention,
             available_backends,
@@ -368,6 +368,7 @@ def test_dot_product_attention(
             and config.attn_mask_type in ["causal", "padding_causal"]
         )
         and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
+        and not is_mla
     ):
         flash_attn_supported = True