Commit 2060dfb

SimJeg committed

Use Qwen3-0.6B

Signed-off-by: SimJeg <[email protected]>

1 parent c76b827 · commit 2060dfb

File tree

3 files changed: +21 -21 lines changed


tests/fixtures.py

Lines changed: 7 additions & 7 deletions
@@ -14,38 +14,38 @@ def get_device():
 
 @pytest.fixture(scope="session")
 def unit_test_model():
-    model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
+    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B").eval()
     return model.to(get_device())
 
 
 @pytest.fixture(scope="session")
 def unit_test_model_output_attention():
     model = AutoModelForCausalLM.from_pretrained(
-        "MaxJeblick/llama2-0b-unit-test", attn_implementation="eager"
+        "Qwen/Qwen3-0.6B", attn_implementation="eager"
     ).eval()
     return model.to(get_device())
 
 
 @pytest.fixture(scope="session")
-def danube_500m_model():
-    model = AutoModelForCausalLM.from_pretrained("h2oai/h2o-danube3-500m-chat").eval()
+def qwen3_600m_model():
+    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B").eval()
     return model.to(get_device())
 
 
 @pytest.fixture(scope="session")
 def kv_press_unit_test_pipeline():
     return pipeline(
         "kv-press-text-generation",
-        model="maxjeblick/llama2-0b-unit-test",
+        model="Qwen/Qwen3-0.6B",
         device=get_device(),
     )
 
 
 @pytest.fixture(scope="session")
-def kv_press_danube_pipeline():
+def kv_press_qwen3_600m_pipeline():
     return pipeline(
         "kv-press-text-generation",
-        model="h2oai/h2o-danube3-500m-chat",
+        model="Qwen/Qwen3-0.6B",
         device=get_device(),
     )
 
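One consequence of the rename above: pytest resolves fixtures by name, so any test that requested danube_500m_model or kv_press_danube_pipeline must now request qwen3_600m_model or kv_press_qwen3_600m_pipeline (see tests/test_pipeline.py below). A minimal sketch of a downstream consumer, with a hypothetical test body that is not part of this commit:

# Hypothetical consumer of the renamed fixture (illustrative only)
from tests.fixtures import qwen3_600m_model  # noqa: F401 (pytest resolves it by name)


def test_qwen3_model_loads(qwen3_600m_model):  # noqa: F811
    # pytest injects the session-scoped Qwen/Qwen3-0.6B model built in fixtures.py
    assert qwen3_600m_model.config.model_type == "qwen3"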

tests/test_decoding_compression.py

Lines changed: 7 additions & 7 deletions
@@ -31,7 +31,7 @@ def test_decoding_compression(token_buffer_size):
     """Test that DecodingPress compresses the cache during decoding."""
 
     # Initialize pipeline with a small model
-    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
 
     # Create a DecodingPress with KnormPress
     press = DecodingPress(
@@ -65,7 +65,7 @@ def test_prefill_decoding_press_calls_both_phases():
     """Test that PrefillDecodingPress calls both prefilling and decoding presses."""
 
     # Initialize pipeline
-    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
 
     # Create PrefillDecodingPress with both presses
     combined_press = PrefillDecodingPress(
@@ -99,7 +99,7 @@ def test_decoding_press_without_prefill():
     """Test that DecodingPress works correctly when used standalone (no prefill compression)."""
 
     # Initialize pipeline
-    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
 
     # Create DecodingPress only
     decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64)
@@ -129,7 +129,7 @@ def test_prefill_decoding_press_decoding_only():
     """Test PrefillDecodingPress with only decoding press (no prefill compression)."""
 
     # Initialize pipeline
-    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
 
     # Create PrefillDecodingPress with only decoding press
     combined_press = PrefillDecodingPress(
@@ -167,7 +167,7 @@ def test_decoding_press_equivalence():
     torch.manual_seed(42)
 
     # Initialize pipeline
-    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
 
     # Create standalone decoding press
     decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52)
@@ -222,7 +222,7 @@ def test_all_presses_work_with_decoding_press(press_config):
     """Test that all default presses work as base presses for DecodingPress."""
 
     # Initialize pipeline
-    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
 
     # Get press class and use the first (easier) configuration
     press_cls = press_config["cls"]
@@ -274,7 +274,7 @@ def test_all_presses_work_with_decoding_press(press_config):
 def test_compression_actually_reduces_memory():
     """Test that compression actually reduces memory usage compared to no compression."""
 
-    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
 
     context = "The quick brown fox jumps over the lazy dog. " * 15  # Long context
     question = "What animal jumps over the dog?"
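For context, the setup these hunks repeatedly touch, reassembled into a standalone sketch. This is illustrative only: the "from kvpress import DecodingPress, KnormPress" path is an assumption, since the file's own import block sits outside the diff.

from kvpress import DecodingPress, KnormPress  # import path assumed
from transformers import pipeline

# Every test in this file now loads Qwen/Qwen3-0.6B instead of the
# MaxJeblick/llama2-0b-unit-test stub
pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")

# Values from test_decoding_compression above: compress with KnormPress every
# 5 decoded tokens, keeping the KV cache near 64 entries
press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64)

context = "The quick brown fox jumps over the lazy dog. " * 15
answers = pipe(context, questions=["What animal jumps over the dog?"], press=press)["answers"]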

tests/test_pipeline.py

Lines changed: 7 additions & 7 deletions
@@ -11,8 +11,8 @@
 
 from kvpress import ExpectedAttentionPress
 from kvpress.pipeline import KVPressTextGenerationPipeline
-from tests.fixtures import danube_500m_model  # noqa: F401
-from tests.fixtures import kv_press_danube_pipeline  # noqa: F401
+from tests.fixtures import qwen3_600m_model  # noqa: F401
+from tests.fixtures import kv_press_qwen3_600m_pipeline  # noqa: F401
 from tests.fixtures import unit_test_model  # noqa: F401
 from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_unit_test_pipeline  # noqa: F401
 
@@ -94,9 +94,9 @@ def test_pipeline_no_press_works(kv_press_unit_test_pipeline, caplog):  # noqa:
     kv_press_unit_test_pipeline(context, question=question)
 
 
-def test_pipeline_answer_is_correct(danube_500m_model, caplog):  # noqa: F811
+def test_pipeline_answer_is_correct(qwen3_600m_model, caplog):  # noqa: F811
     with caplog.at_level(logging.DEBUG):
-        answers = generate_answer(danube_500m_model)
+        answers = generate_answer(qwen3_600m_model)
 
     for answer in answers:
         assert answer == "This article was written on January 1, 2022."
@@ -107,13 +107,13 @@ def test_pipeline_answer_is_correct(danube_500m_model, caplog):  # noqa: F811
 
 
 @pytest.mark.skipif(not is_optimum_quanto_available(), reason="Optimum Quanto is not available")
-def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog):  # noqa: F811
+def test_pipeline_with_quantized_cache(kv_press_qwen3_600m_pipeline, caplog):  # noqa: F811
     with caplog.at_level(logging.DEBUG):
         context = "This is a test article. It was written on 2022-01-01."
         questions = ["When was this article written?"]
         press = ExpectedAttentionPress(compression_ratio=0.4)
-        cache = QuantoQuantizedCache(config=kv_press_danube_pipeline.model.config, nbits=4)
-        answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"]
+        cache = QuantoQuantizedCache(config=kv_press_qwen3_600m_pipeline.model.config, nbits=4)
+        answers = kv_press_qwen3_600m_pipeline(context, questions=questions, press=press, cache=cache)["answers"]
 
     assert len(answers) == 1
     assert isinstance(answers[0], str)
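Reassembled as a standalone sketch, the quantized-cache path that test_pipeline_with_quantized_cache now exercises against Qwen/Qwen3-0.6B looks roughly like this. It requires optimum-quanto, and importing QuantoQuantizedCache from the top-level transformers namespace is an assumption.

from transformers import QuantoQuantizedCache, pipeline

from kvpress import ExpectedAttentionPress

pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")

context = "This is a test article. It was written on 2022-01-01."
press = ExpectedAttentionPress(compression_ratio=0.4)  # drop 40% of KV entries at prefill
cache = QuantoQuantizedCache(config=pipe.model.config, nbits=4)  # 4-bit quantized KV cache
answers = pipe(context, questions=["When was this article written?"], press=press, cache=cache)["answers"]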
