Skip to content

Commit d445048

Browse files
committed
Fix
1 parent 134521d commit d445048

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/test_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,7 +1470,8 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1470 1470       assert can_use_flash_attention(params, True), "can_use_flash_attention(params, True) is False"
1471 1471   elif expected is SDPBackend.EFFICIENT_ATTENTION:
1472 1472       assert mem_efficient_sdp_enabled(), "mem_efficient_sdp_enabled() is False"
1473      -     assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
     1473 +     if (not enable_gqa) or mask is None:
     1474 +         assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
1474 1475   elif expected is SDPBackend.MATH:
1475 1476       assert math_sdp_enabled(), "math_sdp_enabled() is False"
1476 1477   else:
@@ -1538,7 +1539,8 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1538 1539       assert can_use_flash_attention(params, True), "can_use_flash_attention(params, True) is False"
1539 1540   elif expected is SDPBackend.EFFICIENT_ATTENTION:
1540 1541       assert mem_efficient_sdp_enabled(), "mem_efficient_sdp_enabled() is False"
1541      -     assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
     1542 +     if (not enable_gqa) or mask is None:
     1543 +         assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
1542 1544   elif expected is SDPBackend.MATH:
1543 1545       assert math_sdp_enabled(), "math_sdp_enabled() is False"
1544 1546   else:

0 commit comments

Comments
 (0)