@@ -1470,7 +1470,8 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
             assert can_use_flash_attention(params, True), "can_use_flash_attention(params, True) is False"
         elif expected is SDPBackend.EFFICIENT_ATTENTION:
             assert mem_efficient_sdp_enabled(), "mem_efficient_sdp_enabled() is False"
-            assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
+            if (not enable_gqa) or mask is None:
+                assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
         elif expected is SDPBackend.MATH:
             assert math_sdp_enabled(), "math_sdp_enabled() is False"
         else:
@@ -1538,7 +1539,8 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
             assert can_use_flash_attention(params, True), "can_use_flash_attention(params, True) is False"
         elif expected is SDPBackend.EFFICIENT_ATTENTION:
             assert mem_efficient_sdp_enabled(), "mem_efficient_sdp_enabled() is False"
-            assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
+            if (not enable_gqa) or mask is None:
+                assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
         elif expected is SDPBackend.MATH:
             assert math_sdp_enabled(), "math_sdp_enabled() is False"
         else:
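
Context for the relaxed assertion above (not part of the diff): with grouped-query attention enabled and an explicit mask, `can_use_efficient_attention(params, True)` is not guaranteed to return True, so the test now asserts it only when GQA is off or no mask is passed. The sketch below is a hedged illustration of that combination, assuming PyTorch >= 2.5 (for the `enable_gqa` keyword) and a CUDA device; the tensor shapes and the backend restriction via `sdpa_kernel` are illustrative assumptions, not taken from the test suite.

```python
# Illustrative sketch (not part of this PR): probe how the mem-efficient SDPA
# backend reacts to grouped-query attention plus an explicit mask.
# Assumes PyTorch >= 2.5 (enable_gqa kwarg) and a CUDA device.
import torch
import torch.nn.functional as F
from torch.backends.cuda import math_sdp_enabled, mem_efficient_sdp_enabled
from torch.nn.attention import SDPBackend, sdpa_kernel

device, dtype = "cuda", torch.float16
q = torch.randn(1, 8, 32, 64, device=device, dtype=dtype)   # 8 query heads
kv = torch.randn(1, 2, 32, 64, device=device, dtype=dtype)  # 2 key/value heads (GQA)
mask = torch.ones(1, 1, 32, 32, device=device, dtype=torch.bool)

# Global toggles that the test asserts on.
print(mem_efficient_sdp_enabled(), math_sdp_enabled())

# Restrict dispatch to the efficient backend so a rejection surfaces as an
# error instead of a silent fallback to the math backend.
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
    try:
        out = F.scaled_dot_product_attention(q, kv, kv, attn_mask=mask, enable_gqa=True)
        print("efficient backend handled GQA + mask:", out.shape)
    except RuntimeError as err:
        print("efficient backend rejected GQA + mask:", err)
```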