Skip to content

Commit 0eac983

Browse files
committed
Fix proj linear count
1 parent 0967345 commit 0eac983

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tests/test_auto_fp8.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ def test_dynamic_quantization():
3535
"Qwen/Qwen2-0.5B-Instruct",
3636
=======
3737
("facebook/opt-125m", 160),
38+
<<<<<<< HEAD
3839
("Qwen/Qwen2-0.5B-Instruct", 600),
3940
>>>>>>> 415c0b7 (Add fixed target sizes)
41+
=======
42+
("Qwen/Qwen2-0.5B-Instruct", 620),
43+
>>>>>>> 93c0d54 (Fix proj linear count)
4044
]
4145

4246
@pytest.mark.parametrize("model_id,target_size", MODELS)
@@ -180,7 +184,7 @@ def test_kv_cache_static_quantization(model_id, target_size):
180184
proj_linear_count = 0
181185
output_scale_count = 0
182186
for name, _ in tensors.items():
183-
if name.endswith("k_proj") or name.endswith("v_proj"):
187+
if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"):
184188
proj_linear_count += 1
185189
if name.endswith("k_proj.output_scale") or name.endswith("v_proj.output_scale"):
186190
output_scale_count += 1

0 commit comments

Comments
 (0)