@@ -1455,14 +1455,15 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1455
1455
key = key .repeat_interleave (q_per_kv , dim = 1 )
1456
1456
value = value .repeat_interleave (q_per_kv , dim = 1 )
1457
1457
assert query .shape [1 ] == key .shape [1 ]
1458
- k_and_v = DefaultKeysAndValues (key , value )
1458
+ _k_and_v = DefaultKeysAndValues (key , value )
1459
1459
_enable_gqa = False
1460
1460
else :
1461
1461
_enable_gqa = enable_gqa
1462
+ _k_and_v = k_and_v
1462
1463
1463
1464
if hasattr (SDPAParams , "enable_gqa" ):
1464
1465
args .append (_enable_gqa )
1465
- params = SDPAParams (query , k_and_v .keys (), k_and_v .values (), mask , 0.0 , True , * args )
1466
+ params = SDPAParams (query , _k_and_v .keys (), _k_and_v .values (), mask , 0.0 , True , * args )
1466
1467
if expected is SDPBackend .FLASH_ATTENTION :
1467
1468
assert flash_sdp_enabled (), "flash_sdp_enabled() is False"
1468
1469
if config .sliding_window_size is None :
@@ -1523,14 +1524,15 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1523
1524
key = key .repeat_interleave (q_per_kv , dim = 1 )
1524
1525
value = value .repeat_interleave (q_per_kv , dim = 1 )
1525
1526
assert query .shape [1 ] == key .shape [1 ]
1526
- k_and_v = DefaultKeysAndValues (key , value )
1527
+ _k_and_v = DefaultKeysAndValues (key , value )
1527
1528
_enable_gqa = False
1528
1529
else :
1529
1530
_enable_gqa = enable_gqa
1531
+ _k_and_v = k_and_v
1530
1532
1531
1533
if hasattr (SDPAParams , "enable_gqa" ):
1532
1534
args .append (_enable_gqa )
1533
- params = SDPAParams (query , k_and_v .keys (), k_and_v .values (), mask , 0.0 , True , * args )
1535
+ params = SDPAParams (query , _k_and_v .keys (), _k_and_v .values (), mask , 0.0 , True , * args )
1534
1536
if expected is SDPBackend .FLASH_ATTENTION :
1535
1537
assert flash_sdp_enabled (), "flash_sdp_enabled() is False"
1536
1538
assert can_use_flash_attention (params , True ), "can_use_flash_attention(params, True) is False"
0 commit comments