Commit 415c0b7: Add fixed target sizes

1 parent 529dfef

File tree: 1 file changed (+12 additions, -12 deletions)


tests/test_auto_fp8.py (12 additions, 12 deletions)

@@ -8,12 +8,12 @@
 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

 MODELS = [
-    "facebook/opt-125m",
-    "Qwen/Qwen2-0.5B-Instruct",
+    ("facebook/opt-125m", 160),
+    ("Qwen/Qwen2-0.5B-Instruct", 600),
 ]

-@pytest.mark.parametrize("model_id", MODELS)
-def test_dynamic_quantization(model_id):
+@pytest.mark.parametrize("model_id,target_size", MODELS)
+def test_dynamic_quantization(model_id, target_size):
     quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic"

     quantize_config = BaseQuantizeConfig(
@@ -30,13 +30,13 @@ def test_dynamic_quantization(model_id):
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)

-    # We expect the model to be < 160MB
-    target_size = 160 * (1024 * 1024)
+    # We expect the model to be a certain size
+    target_size = target_size * (1024 * 1024)
     assert model_size < target_size


-@pytest.mark.parametrize("model_id", MODELS)
-def test_static_quantization(model_id):
+@pytest.mark.parametrize("model_id,target_size", MODELS)
+def test_static_quantization(model_id, target_size):
     quantized_model_dir = model_id.split("/")[-1] + "-fp8-static"

     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
@@ -56,11 +56,11 @@ def test_static_quantization(model_id):
     shutil.rmtree(quantized_model_dir)

     # We expect the model to be < 160MB
-    target_size = 160 * (1024 * 1024)
+    target_size = target_size * (1024 * 1024)
     assert model_size < target_size

-@pytest.mark.parametrize("model_id", MODELS)
-def test_kv_cache_static_quantization(model_id):
+@pytest.mark.parametrize("model_id,target_size", MODELS)
+def test_kv_cache_static_quantization(model_id, target_size):
     quantized_model_dir = model_id.split("/")[-1] + "-fp8-static-kv"

     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
@@ -94,5 +94,5 @@ def test_kv_cache_static_quantization(model_id):
     shutil.rmtree(quantized_model_dir)

     # We expect the model to be < 160MB
-    target_size = 160 * (1024 * 1024)
+    target_size = target_size * (1024 * 1024)
     assert model_size < target_size
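
The mechanics of the change, in brief: passing a comma-separated argname string to pytest.mark.parametrize makes pytest unpack each (model_id, target_size) tuple into separate arguments, so every test runs once per model with its own size budget in MB. A minimal standalone sketch of that pattern (the test name test_size_budget and the stand-in size value are illustrative, not part of this commit):

import pytest

# Mirrors the MODELS list from this commit: (model_id, size budget in MB).
MODELS = [
    ("facebook/opt-125m", 160),
    ("Qwen/Qwen2-0.5B-Instruct", 600),
]

# "model_id,target_size" tells pytest to unpack each tuple into two
# arguments, generating one test case per entry.
@pytest.mark.parametrize("model_id,target_size", MODELS)
def test_size_budget(model_id, target_size):
    # Convert the MB budget to bytes, exactly as the real tests do.
    target_bytes = target_size * (1024 * 1024)
    # Stand-in (illustrative) for os.path.getsize(...) on the checkpoint.
    fake_model_size = 100 * (1024 * 1024)
    assert fake_model_size < target_bytes

With the three tests in this file parametrized the same way, pytest collects six cases (three tests x two models). One nit the diff leaves behind: the "< 160MB" comments kept as unchanged context in the static and KV-cache tests are now stale, since the budget is per-model.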
