@@ -1435,13 +1435,14 @@ def test_sdpa_choice(config):
         pytest.skip("Gemma 2 doesn't support SDPA")
 
     torch.set_default_dtype(torch.float16)
+    enable_gqa = config["n_query_groups"] < config["n_head"]
 
     def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
         # SDPAParams gained an additional argument in PyTorch 2.5
         args = []
         assert k_and_v.both_in_parallel()
         if hasattr(SDPAParams, "enable_gqa"):
-            args.append(False)
+            args.append(enable_gqa)
         params = SDPAParams(query, k_and_v.keys(), k_and_v.values(), mask, 0.0, True, *args)
         if expected is SDPBackend.FLASH_ATTENTION:
             assert flash_sdp_enabled(), "flash_sdp_enabled() is False"
@@ -1487,13 +1488,14 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
 @torch.inference_mode()
 def test_sdpa_choice_kv_cache(config):
     torch.set_default_dtype(torch.float16)
+    enable_gqa = config["n_query_groups"] < config["n_head"]
 
     def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
         # SDPAParams gained an additional argument in PyTorch 2.5
         args = []
         assert k_and_v.both_in_parallel()
         if hasattr(SDPAParams, "enable_gqa"):
-            args.append(False)
+            args.append(enable_gqa)
         params = SDPAParams(query, k_and_v.keys(), k_and_v.values(), mask, 0.0, True, *args)
         if expected is SDPBackend.FLASH_ATTENTION:
             assert flash_sdp_enabled()
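
For context on the flag being threaded through here: PyTorch 2.5 added an `enable_gqa` argument to scaled dot-product attention, which lets grouped-query models pass fewer key/value heads than query heads without repeating them first. The diff derives the flag from the model config (`n_query_groups < n_head`) and forwards it to `SDPAParams` so the backend check matches what the model will actually request. Below is a minimal sketch of the same condition driving the public functional API rather than `SDPAParams`; the tensor shapes and dimension names are illustrative, not taken from the test suite, and it assumes PyTorch >= 2.5:

```python
import torch
import torch.nn.functional as F

batch, seq_len, head_dim = 1, 8, 16
n_head, n_query_groups = 8, 2  # grouped-query attention: 4 query heads per KV head

query = torch.randn(batch, n_head, seq_len, head_dim)
key = torch.randn(batch, n_query_groups, seq_len, head_dim)
value = torch.randn(batch, n_query_groups, seq_len, head_dim)

# Mirrors the condition the tests compute from the model config:
enable_gqa = n_query_groups < n_head

# With enable_gqa=True, SDPA broadcasts the KV heads across the query-head
# groups itself; without it, key/value would have to be repeated up to
# n_head heads before the call.
out = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=enable_gqa)
print(out.shape)  # torch.Size([1, 8, 8, 16])
```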