@@ -1471,6 +1471,8 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1471
1471
elif expected is SDPBackend .EFFICIENT_ATTENTION :
1472
1472
assert mem_efficient_sdp_enabled (), "mem_efficient_sdp_enabled() is False"
1473
1473
if (not enable_gqa ) or mask is None :
1474
+ # At present, `SDPBackend.EFFICIENT_ATTENTION` does not support
1475
+ # `enable_gqa=True` and a mask specified
1474
1476
assert can_use_efficient_attention (params , True ), "can_use_efficient_attention(params, True) is False"
1475
1477
elif expected is SDPBackend .MATH :
1476
1478
assert math_sdp_enabled (), "math_sdp_enabled() is False"
@@ -1540,6 +1542,8 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1540
1542
elif expected is SDPBackend .EFFICIENT_ATTENTION :
1541
1543
assert mem_efficient_sdp_enabled (), "mem_efficient_sdp_enabled() is False"
1542
1544
if (not enable_gqa ) or mask is None :
1545
+ # At present, `SDPBackend.EFFICIENT_ATTENTION` does not support
1546
+ # `enable_gqa=True` and a mask specified
1543
1547
assert can_use_efficient_attention (params , True ), "can_use_efficient_attention(params, True) is False"
1544
1548
elif expected is SDPBackend .MATH :
1545
1549
assert math_sdp_enabled (), "math_sdp_enabled() is False"
0 commit comments