@@ -1456,9 +1456,12 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1456
1456
value = value .repeat_interleave (q_per_kv , dim = 1 )
1457
1457
assert query .shape [1 ] == key .shape [1 ]
1458
1458
k_and_v = DefaultKeysAndValues (key , value )
1459
+ _enable_gqa = False
1460
+ else :
1461
+ _enable_gqa = enable_gqa
1459
1462
1460
1463
if hasattr (SDPAParams , "enable_gqa" ):
1461
- args .append (enable_gqa )
1464
+ args .append (_enable_gqa )
1462
1465
params = SDPAParams (query , k_and_v .keys (), k_and_v .values (), mask , 0.0 , True , * args )
1463
1466
if expected is SDPBackend .FLASH_ATTENTION :
1464
1467
assert flash_sdp_enabled (), "flash_sdp_enabled() is False"
@@ -1521,9 +1524,12 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1521
1524
value = value .repeat_interleave (q_per_kv , dim = 1 )
1522
1525
assert query .shape [1 ] == key .shape [1 ]
1523
1526
k_and_v = DefaultKeysAndValues (key , value )
1527
+ _enable_gqa = False
1528
+ else :
1529
+ _enable_gqa = enable_gqa
1524
1530
1525
1531
if hasattr (SDPAParams , "enable_gqa" ):
1526
- args .append (enable_gqa )
1532
+ args .append (_enable_gqa )
1527
1533
params = SDPAParams (query , k_and_v .keys (), k_and_v .values (), mask , 0.0 , True , * args )
1528
1534
if expected is SDPBackend .FLASH_ATTENTION :
1529
1535
assert flash_sdp_enabled (), "flash_sdp_enabled() is False"
0 commit comments