Skip to content

Commit ffd81cc

Browse files
committed
Remove deprecated tests
Signed-off-by: SimJeg <[email protected]>
1 parent d52e7b1 commit ffd81cc

File tree

2 files changed

+12
-78
lines changed

2 files changed

+12
-78
lines changed

tests/presses/test_wrappers.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

tests/test_pipeline.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@
1717
from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_unit_test_pipeline # noqa: F401
1818

1919

20-
# def test_pipeline(kv_press_unit_test_pipeline, caplog): # noqa: F811
21-
# with caplog.at_level(logging.DEBUG):
22-
# context = "This is a test article. It was written on 2022-01-01."
23-
# questions = ["When was this article written?"]
24-
# press = ExpectedAttentionPress(compression_ratio=0.4)
25-
# answers = kv_press_unit_test_pipeline(context, questions=questions, press=press)["answers"]
20+
def test_pipeline(kv_press_unit_test_pipeline, caplog): # noqa: F811
21+
with caplog.at_level(logging.DEBUG):
22+
context = "This is a test article. It was written on 2022-01-01."
23+
questions = ["When was this article written?"]
24+
press = ExpectedAttentionPress(compression_ratio=0.4)
25+
answers = kv_press_unit_test_pipeline(context, questions=questions, press=press)["answers"]
2626

27-
# assert len(answers) == 1
28-
# assert isinstance(answers[0], str)
27+
assert len(answers) == 1
28+
assert isinstance(answers[0], str)
2929

30-
# messages = [record.message for record in caplog.records]
31-
# assert "Context Length: 23" in messages, messages
32-
# assert "Compressed Context Length: 13" in messages, messages
30+
messages = [record.message for record in caplog.records]
31+
assert "Context Length: 23" in messages, messages
32+
assert "Compressed Context Length: 13" in messages, messages
3333

3434

3535
def test_pipeline_with_cache(kv_press_unit_test_pipeline): # noqa: F811
@@ -47,6 +47,7 @@ class TestPipelineFA2:
4747
@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
4848
@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
4949
@pytest.mark.parametrize("compression_ratio", [0.0, 0.2])
50+
@pytest.mark.xfail(reason="Known issue not related to kvpress", strict=False)
5051
def test_pipeline_fa2(self, kv_press_llama3_2_flash_attn_pipeline, compression_ratio): # noqa: F811
5152
context = "This is a test article. It was written on 2022-01-01."
5253
questions = ["Repeat the last sentence"]

0 commit comments

Comments
 (0)