Skip to content

Commit 134521d

Browse files
committed
Fix
1 parent d048904 commit 134521d

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/test_model.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1455,14 +1455,15 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1455 1455
key = key.repeat_interleave(q_per_kv, dim=1)
1456 1456
value = value.repeat_interleave(q_per_kv, dim=1)
1457 1457
assert query.shape[1] == key.shape[1]
1458 -
k_and_v = DefaultKeysAndValues(key, value)
1458 +
_k_and_v = DefaultKeysAndValues(key, value)
1459 1459
_enable_gqa = False
1460 1460
else:
1461 1461
_enable_gqa = enable_gqa
1462 +
_k_and_v = k_and_v
1462 1463

1463 1464
if hasattr(SDPAParams, "enable_gqa"):
1464 1465
args.append(_enable_gqa)
1465 -
params = SDPAParams(query, k_and_v.keys(), k_and_v.values(), mask, 0.0, True, *args)
1466 +
params = SDPAParams(query, _k_and_v.keys(), _k_and_v.values(), mask, 0.0, True, *args)
1466 1467
if expected is SDPBackend.FLASH_ATTENTION:
1467 1468
assert flash_sdp_enabled(), "flash_sdp_enabled() is False"
1468 1469
if config.sliding_window_size is None:
@@ -1523,14 +1524,15 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
1523 1524
key = key.repeat_interleave(q_per_kv, dim=1)
1524 1525
value = value.repeat_interleave(q_per_kv, dim=1)
1525 1526
assert query.shape[1] == key.shape[1]
1526 -
k_and_v = DefaultKeysAndValues(key, value)
1527 +
_k_and_v = DefaultKeysAndValues(key, value)
1527 1528
_enable_gqa = False
1528 1529
else:
1529 1530
_enable_gqa = enable_gqa
1531 +
_k_and_v = k_and_v
1530 1532

1531 1533
if hasattr(SDPAParams, "enable_gqa"):
1532 1534
args.append(_enable_gqa)
1533 -
params = SDPAParams(query, k_and_v.keys(), k_and_v.values(), mask, 0.0, True, *args)
1535 +
params = SDPAParams(query, _k_and_v.keys(), _k_and_v.values(), mask, 0.0, True, *args)
1534 1536
if expected is SDPBackend.FLASH_ATTENTION:
1535 1537
assert flash_sdp_enabled(), "flash_sdp_enabled() is False"
1536 1538
assert can_use_flash_attention(params, True), "can_use_flash_attention(params, True) is False"

0 commit comments

Comments
 (0)