Skip to content

Commit 35a21af

Browse files
committed
Fix kv cache test count
1 parent 5831ba9 commit 35a21af

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/test_auto_fp8.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,14 @@ def test_kv_cache_static_quantization(model_id):
158158
model.save_quantized(quantized_model_dir)
159159

160160
tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors")
161-
count_matches = 0
162-
for name, tensor in tensors.items():
161+
proj_linear_count = 0
162+
output_scale_count = 0
163+
for name, _ in tensors.items():
164+
if name.endswith("k_proj") or name.endswith("v_proj"):
165+
proj_linear_count += 1
163166
if name.endswith("k_proj.output_scale") or name.endswith("v_proj.output_scale"):
164-
count_matches += 1
165-
assert count_matches == 24
167+
output_scale_count += 1
168+
assert proj_linear_count == output_scale_count
166169

167170
# Measure checkpoint size and cleanup
168171
model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")

0 commit comments

Comments
 (0)