@@ -1444,6 +1444,19 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1444
1444
# SDPAParams gained an additional argument in PyTorch 2.5
1445
1445
args = []
1446
1446
assert k_and_v .both_in_parallel ()
1447
+ # This is also done in `MultiHeadSelfAttention.scaled_dot_product_attention`
1448
+ if mask is None and enable_gqa :
1449
+ # Some efficient kernels have not implemented
1450
+ # `enable_gqa=True`. It is better to expand the keys and values in
1451
+ # this case.
1452
+ key = k_and_v .keys ()
1453
+ value = k_and_v .values ()
1454
+ q_per_kv = config .n_head // config .n_query_groups
1455
+ key = key .repeat_interleave (q_per_kv , dim = 1 )
1456
+ value = value .repeat_interleave (q_per_kv , dim = 1 )
1457
+ assert query .shape [1 ] == key .shape [1 ]
1458
+ k_and_v = DefaultKeysAndValues (key , value )
1459
+
1447
1460
if hasattr (SDPAParams , "enable_gqa" ):
1448
1461
args .append (enable_gqa )
1449
1462
params = SDPAParams (query , k_and_v .keys (), k_and_v .values (), mask , 0.0 , True , * args )
@@ -1506,6 +1519,7 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1506
1519
q_per_kv = config .n_head // config .n_query_groups
1507
1520
key = key .repeat_interleave (q_per_kv , dim = 1 )
1508
1521
value = value .repeat_interleave (q_per_kv , dim = 1 )
1522
+ assert query .shape [1 ] == key .shape [1 ]
1509
1523
k_and_v = DefaultKeysAndValues (key , value )
1510
1524
1511
1525
if hasattr (SDPAParams , "enable_gqa" ):
0 commit comments