@@ -1438,7 +1438,7 @@ def test_sdpa_choice(config):
 
     torch.set_default_dtype(torch.float16)
 
-    def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logit_softcapping):
+    def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
         # SDPAParams gained an additional argument in PyTorch 2.5
         args = []
         assert k_and_v.both_in_parallel()
@@ -1456,7 +1456,7 @@ def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logi
             assert math_sdp_enabled(), "math_sdp_enabled() is False"
         else:
             raise NotImplementedError
-        return original_fn(query, k_and_v, scale, mask, attention_logit_softcapping)
+        return original_fn(query, k_and_v, mask, return_scores)
 
     config["n_layer"] = 1
     config = config_module.Config(**config)
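Aside on the `math_sdp_enabled()` assertion in the hunk above: those flags live in `torch.backends.cuda` and are toggled by the `torch.nn.attention.sdpa_kernel` context manager (PyTorch >= 2.3), which is how a test like this pins SDPA to a single backend. A minimal sketch, independent of litgpt and runnable on CPU:

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Restricting SDPA to the math backend flips the global flags the assertions read.
with sdpa_kernel(SDPBackend.MATH):
    assert torch.backends.cuda.math_sdp_enabled(), "math_sdp_enabled() is False"
    assert not torch.backends.cuda.flash_sdp_enabled()
# Outside the context manager the previous flag values are restored.
```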
@@ -1469,10 +1469,9 @@ def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logi
         # best effort, if the GPU can load it
         pytest.xfail()
 
-    for h in model.transformer.h:
-        litgpt.attention.scaled_dot_product_attention = partial(
-            assert_sdpa_backend, litgpt.attention.scaled_dot_product_attention
-        )
+    model.mha.scaled_dot_product_attention = partial(
+        assert_sdpa_backend, model.mha.scaled_dot_product_attention,
+    )
 
     if SUPPORTS_FLASH_ATTENTION:
         expected = SDPBackend.FLASH_ATTENTION
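The replacement hunk above swaps the module-level monkeypatch for patching `model.mha.scaled_dot_product_attention` directly. The mechanism is the usual `functools.partial` shim: bind the original callable as the first argument of an asserting wrapper, then assign the wrapper back to the attribute the model calls. A self-contained sketch of that pattern (the `ToyAttention` class and `asserting_shim` name are illustrative, not litgpt's API):

```python
from functools import partial


class ToyAttention:
    """Stand-in for an object exposing a scaled_dot_product_attention method."""

    def scaled_dot_product_attention(self, query, k_and_v, mask=None, return_scores=False):
        return query  # placeholder for the real computation


def asserting_shim(original_fn, query, k_and_v, mask, return_scores):
    # Checks run first, then the captured original implementation is called.
    assert query is not None
    return original_fn(query, k_and_v, mask, return_scores)


toy = ToyAttention()
# Bind the original bound method as `original_fn`; the partial replaces the instance
# attribute, so later calls pass through the assertion before reaching the real code.
toy.scaled_dot_product_attention = partial(asserting_shim, toy.scaled_dot_product_attention)
toy.scaled_dot_product_attention("q", "kv", None, False)
```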
@@ -1490,7 +1489,7 @@ def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logi
 def test_sdpa_choice_kv_cache(config):
     torch.set_default_dtype(torch.float16)
 
-    def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logit_softcapping):
+    def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
         # SDPAParams gained an additional argument in PyTorch 2.5
         args = []
         assert k_and_v.both_in_parallel()
@@ -1507,7 +1506,7 @@ def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logi
             assert math_sdp_enabled()
         else:
             raise NotImplementedError
-        return original_fn(query, k_and_v, scale, mask, attention_logit_softcapping)
+        return original_fn(query, k_and_v, mask, return_scores)
 
     config["n_layer"] = 1
     config = config_module.Config(**config)
@@ -1522,10 +1521,9 @@ def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logi
         # best effort, if the GPU can load it
         pytest.xfail()
 
-    for h in model.transformer.h:
-        litgpt.attention.scaled_dot_product_attention = partial(
-            assert_sdpa_backend, litgpt.attention.scaled_dot_product_attention
-        )
+    model.mha.scaled_dot_product_attention = partial(
+        assert_sdpa_backend, model.mha.scaled_dot_product_attention,
+    )
 
     if SUPPORTS_FLASH_ATTENTION:
         # flash attention does not support an attention mask
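For reference, the `args = []` handling and the "# SDPAParams gained an additional argument in PyTorch 2.5" comment in both tests refer to the extra flag appended to `torch.backends.cuda.SDPAParams` in 2.5 (assumed here to be `enable_gqa`). A hedged sketch of how such a version guard typically looks; it assumes a CUDA device, and the constructor arity should be verified against the installed PyTorch:

```python
import torch
from torch.backends.cuda import SDPAParams, can_use_flash_attention

q = k = v = torch.randn(1, 8, 16, 64, dtype=torch.float16, device="cuda")

# query, key, value, attn_mask, dropout, is_causal
args = [q, k, v, None, 0.0, True]
if tuple(int(p) for p in torch.__version__.split(".")[:2]) >= (2, 5):
    args.append(False)  # assumed to be enable_gqa, the argument added in PyTorch 2.5
params = SDPAParams(*args)
print(can_use_flash_attention(params, True))  # debug=True logs why flash attention is rejected
```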