Skip to content

Commit b88e346

Browse files
committed
Fix
1 parent 54edeca commit b88e346

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/test_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,13 +1435,14 @@ def test_sdpa_choice(config):
1435 1435
pytest.skip("Gemma 2 doesn't support SDPA")
1436 1436

1437 1437
torch.set_default_dtype(torch.float16)
1438+
enable_gqa = config["n_query_groups"] < config["n_head"]
1438 1439

1439 1440
def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1440 1441
# SDPAParams gained an additional argument in PyTorch 2.5
1441 1442
args = []
1442 1443
assert k_and_v.both_in_parallel()
1443 1444
if hasattr(SDPAParams, "enable_gqa"):
1444-
args.append(False)
1445+
args.append(enable_gqa)
1445 1446
params = SDPAParams(query, k_and_v.keys(), k_and_v.values(), mask, 0.0, True, *args)
1446 1447
if expected is SDPBackend.FLASH_ATTENTION:
1447 1448
assert flash_sdp_enabled(), "flash_sdp_enabled() is False"
@@ -1487,13 +1488,14 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1487 1488
@torch.inference_mode()
1488 1489
def test_sdpa_choice_kv_cache(config):
1489 1490
torch.set_default_dtype(torch.float16)
1491+
enable_gqa = config["n_query_groups"] < config["n_head"]
1490 1492

1491 1493
def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1492 1494
# SDPAParams gained an additional argument in PyTorch 2.5
1493 1495
args = []
1494 1496
assert k_and_v.both_in_parallel()
1495 1497
if hasattr(SDPAParams, "enable_gqa"):
1496-
args.append(False)
1498+
args.append(enable_gqa)
1497 1499
params = SDPAParams(query, k_and_v.keys(), k_and_v.values(), mask, 0.0, True, *args)
1498 1500
if expected is SDPBackend.FLASH_ATTENTION:
1499 1501
assert flash_sdp_enabled()

0 commit comments

Comments (0)