Skip to content

Commit 22856a7

Browse files
committed
Fix
1 parent d92eac6 commit 22856a7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/test_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,7 +1437,7 @@ def test_sdpa_choice(config):
14371437
torch.set_default_dtype(torch.float16)
14381438
config["n_layer"] = 1
14391439
config = config_module.Config(**config)
1440-
enable_gqa = config["n_query_groups"] < config["n_head"]
1440+
enable_gqa = config.n_query_groups < config.n_head
14411441

14421442
def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
14431443
# SDPAParams gained an additional argument in PyTorch 2.5
@@ -1489,7 +1489,7 @@ def test_sdpa_choice_kv_cache(config):
14891489
torch.set_default_dtype(torch.float16)
14901490
config["n_layer"] = 1
14911491
config = config_module.Config(**config)
1492-
enable_gqa = config["n_query_groups"] < config["n_head"]
1492+
enable_gqa = config.n_query_groups < config.n_head
14931493

14941494
def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
14951495
# SDPAParams gained an additional argument in PyTorch 2.5

0 commit comments

Comments
 (0)