Skip to content

Commit 38595c4

Browse files
committed
Fix
1 parent 73bdd62 commit 38595c4

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/test_model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,10 +1520,12 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
15201520
# best effort, if the GPU can load it
15211521
pytest.xfail()
15221522

1523-
model.mha.scaled_dot_product_attention = partial(
1524-
assert_sdpa_backend,
1525-
model.mha.scaled_dot_product_attention,
1526-
)
1523+
for block in model.transformer.h:
1524+
kv_cache = block.attn.kv_cache
1525+
kv_cache.mha.scaled_dot_product_attention = partial(
1526+
assert_sdpa_backend,
1527+
kv_cache.mha.scaled_dot_product_attention,
1528+
)
15271529

15281530
if SUPPORTS_FLASH_ATTENTION:
15291531
# flash attention does not support an attention mask

0 commit comments

Comments
 (0)