Commit 37051e5

Fix test
1 parent 98dfe11 commit 37051e5

1 file changed: 10 additions, 12 deletions

tests/test_model.py

Lines changed: 10 additions & 12 deletions
@@ -1438,7 +1438,7 @@ def test_sdpa_choice(config):
 
     torch.set_default_dtype(torch.float16)
 
-    def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logit_softcapping):
+    def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
         # SDPAParams gained an additional argument in PyTorch 2.5
         args = []
         assert k_and_v.both_in_parallel()
@@ -1456,7 +1456,7 @@ def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logi
             assert math_sdp_enabled(), "math_sdp_enabled() is False"
         else:
             raise NotImplementedError
-        return original_fn(query, k_and_v, scale, mask, attention_logit_softcapping)
+        return original_fn(query, k_and_v, mask, return_scores)
 
     config["n_layer"] = 1
     config = config_module.Config(**config)
@@ -1469,10 +1469,9 @@ def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logi
         # best effort, if the GPU can load it
         pytest.xfail()
 
-    for h in model.transformer.h:
-        litgpt.attention.scaled_dot_product_attention = partial(
-            assert_sdpa_backend, litgpt.attention.scaled_dot_product_attention
-        )
+    model.mha.scaled_dot_product_attention = partial(
+        assert_sdpa_backend, model.mha.scaled_dot_product_attention,
+    )
 
     if SUPPORTS_FLASH_ATTENTION:
         expected = SDPBackend.FLASH_ATTENTION
@@ -1490,7 +1489,7 @@ def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logi
 def test_sdpa_choice_kv_cache(config):
     torch.set_default_dtype(torch.float16)
 
-    def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logit_softcapping):
+    def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
         # SDPAParams gained an additional argument in PyTorch 2.5
         args = []
         assert k_and_v.both_in_parallel()
@@ -1507,7 +1506,7 @@ def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logi
             assert math_sdp_enabled()
         else:
             raise NotImplementedError
-        return original_fn(query, k_and_v, scale, mask, attention_logit_softcapping)
+        return original_fn(query, k_and_v, mask, return_scores)
 
     config["n_layer"] = 1
     config = config_module.Config(**config)
@@ -1522,10 +1521,9 @@ def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logi
         # best effort, if the GPU can load it
         pytest.xfail()
 
-    for h in model.transformer.h:
-        litgpt.attention.scaled_dot_product_attention = partial(
-            assert_sdpa_backend, litgpt.attention.scaled_dot_product_attention
-        )
+    model.mha.scaled_dot_product_attention = partial(
+        assert_sdpa_backend, model.mha.scaled_dot_product_attention,
+    )
 
     if SUPPORTS_FLASH_ATTENTION:
         # flash attention does not support an attention mask
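
For context, the change replaces per-block monkeypatching of the module-level litgpt.attention.scaled_dot_product_attention with a single patch on model.mha, and the wrapper signature drops scale / attention_logit_softcapping in favour of mask / return_scores. Below is a minimal, self-contained sketch of the functools.partial wrapping pattern the test relies on; fake_sdpa and the literal arguments are hypothetical stand-ins rather than litgpt code, and the real assert_sdpa_backend additionally inspects SDPAParams and the flash/efficient/math backend flags.

from functools import partial


def fake_sdpa(query, k_and_v, mask, return_scores):
    # hypothetical stand-in for model.mha.scaled_dot_product_attention
    return query


def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
    # the real test asserts which SDPBackend would be chosen here;
    # this sketch only verifies the wrapper runs before delegating
    assert query is not None
    return original_fn(query, k_and_v, mask, return_scores)


# same pattern as the diff: partial() injects the original callable as the
# first argument, so every call passes through the assertion first
wrapped = partial(assert_sdpa_backend, fake_sdpa)
assert wrapped("q", k_and_v=None, mask=None, return_scores=False) == "q"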
