Skip to content

Commit d048904

Browse files
committed
Fix
1 parent d659391 commit d048904

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tests/test_model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,9 +1456,12 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
14561456
value = value.repeat_interleave(q_per_kv, dim=1)
14571457
assert query.shape[1] == key.shape[1]
14581458
k_and_v = DefaultKeysAndValues(key, value)
1459+
_enable_gqa = False
1460+
else:
1461+
_enable_gqa = enable_gqa
14591462

14601463
if hasattr(SDPAParams, "enable_gqa"):
1461-
args.append(enable_gqa)
1464+
args.append(_enable_gqa)
14621465
params = SDPAParams(query, k_and_v.keys(), k_and_v.values(), mask, 0.0, True, *args)
14631466
if expected is SDPBackend.FLASH_ATTENTION:
14641467
assert flash_sdp_enabled(), "flash_sdp_enabled() is False"
@@ -1521,9 +1524,12 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
15211524
value = value.repeat_interleave(q_per_kv, dim=1)
15221525
assert query.shape[1] == key.shape[1]
15231526
k_and_v = DefaultKeysAndValues(key, value)
1527+
_enable_gqa = False
1528+
else:
1529+
_enable_gqa = enable_gqa
15241530

15251531
if hasattr(SDPAParams, "enable_gqa"):
1526-
args.append(enable_gqa)
1532+
args.append(_enable_gqa)
15271533
params = SDPAParams(query, k_and_v.keys(), k_and_v.values(), mask, 0.0, True, *args)
15281534
if expected is SDPBackend.FLASH_ATTENTION:
15291535
assert flash_sdp_enabled(), "flash_sdp_enabled() is False"

0 commit comments

Comments
 (0)