Skip to content

Commit 529dfef

Browse files
committed
Fix kv cache test count
1 parent 2739d61 commit 529dfef

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/test_auto_fp8.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -80,11 +80,14 @@ def test_kv_cache_static_quantization(model_id):
8080
model.save_quantized(quantized_model_dir)
8181

8282
tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors")
83-
count_matches = 0
84-
for name, tensor in tensors.items():
83+
proj_linear_count = 0
84+
output_scale_count = 0
85+
for name, _ in tensors.items():
86+
if name.endswith("k_proj") or name.endswith("v_proj"):
87+
proj_linear_count += 1
8588
if name.endswith("k_proj.output_scale") or name.endswith("v_proj.output_scale"):
86-
count_matches += 1
87-
assert count_matches == 24
89+
output_scale_count += 1
90+
assert proj_linear_count == output_scale_count
8891

8992
# Measure checkpoint size and cleanup
9093
model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")

0 commit comments

Comments (0)