Skip to content

Commit d92eac6

Browse files
committed
Fix
1 parent b88e346 commit d92eac6

File tree

1 file changed: +4 additions, −6 deletions

tests/test_model.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1435,6 +1435,8 @@ def test_sdpa_choice(config):
1435 1435          pytest.skip("Gemma 2 doesn't support SDPA")
1436 1436
1437 1437      torch.set_default_dtype(torch.float16)
     1438 +    config["n_layer"] = 1
     1439 +    config = config_module.Config(**config)
1438 1440      enable_gqa = config["n_query_groups"] < config["n_head"]
1439 1441
1440 1442      def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
@@ -1457,9 +1459,6 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1457 1459              raise NotImplementedError
1458 1460          return original_fn(query, k_and_v, mask, return_scores)
1459 1461
1460      -    config["n_layer"] = 1
1461      -    config = config_module.Config(**config)
1462      -
1463 1462      try:
1464 1463          with torch.device("cuda"):
1465 1464              model = GPT(config)
@@ -1488,6 +1487,8 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1488 1487  @torch.inference_mode()
1489 1488  def test_sdpa_choice_kv_cache(config):
1490 1489      torch.set_default_dtype(torch.float16)
     1490 +    config["n_layer"] = 1
     1491 +    config = config_module.Config(**config)
1491 1492      enable_gqa = config["n_query_groups"] < config["n_head"]
1492 1493
1493 1494      def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
@@ -1509,9 +1510,6 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1509 1510              raise NotImplementedError
1510 1511          return original_fn(query, k_and_v, mask, return_scores)
1511 1512
1512      -    config["n_layer"] = 1
1513      -    config = config_module.Config(**config)
1514      -
1515 1513      try:
1516 1514          with torch.device("cuda"):
1517 1515              model = GPT(config)

0 commit comments

Comments (0)