Commit 0967345

Add fixed target sizes
1 parent 35a21af commit 0967345

File tree

1 file changed: +24 -5 lines changed


tests/test_auto_fp8.py

Lines changed: 24 additions & 5 deletions
@@ -15,6 +15,7 @@
 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
 
 MODELS = [
+<<<<<<< HEAD
 <<<<<<< HEAD
     ("facebook/opt-125m", 160),
     ("Qwen/Qwen2-0.5B-Instruct", 620),
@@ -32,10 +33,14 @@ def test_dynamic_quantization():
 =======
     "facebook/opt-125m",
     "Qwen/Qwen2-0.5B-Instruct",
+=======
+    ("facebook/opt-125m", 160),
+    ("Qwen/Qwen2-0.5B-Instruct", 600),
+>>>>>>> 415c0b7 (Add fixed target sizes)
 ]
 
-@pytest.mark.parametrize("model_id", MODELS)
-def test_dynamic_quantization(model_id):
+@pytest.mark.parametrize("model_id,target_size", MODELS)
+def test_dynamic_quantization(model_id, target_size):
     quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic"
 >>>>>>> 2739d61 (Add Qwen test)
 
@@ -53,6 +58,7 @@ def test_dynamic_quantization(model_id):
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)
 
+<<<<<<< HEAD
 <<<<<<< HEAD
     # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
@@ -76,6 +82,15 @@ def test_static_quantization():
 =======
 @pytest.mark.parametrize("model_id", MODELS)
 def test_static_quantization(model_id):
+=======
+    # We expect the model to be a certain size
+    target_size = target_size * (1024 * 1024)
+    assert model_size < target_size
+
+
+@pytest.mark.parametrize("model_id,target_size", MODELS)
+def test_static_quantization(model_id, target_size):
+>>>>>>> 415c0b7 (Add fixed target sizes)
     quantized_model_dir = model_id.split("/")[-1] + "-fp8-static"
 >>>>>>> 2739d61 (Add Qwen test)
 
@@ -95,6 +110,7 @@ def test_static_quantization(model_id):
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)
 
+<<<<<<< HEAD
     # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size
@@ -134,11 +150,14 @@ def test_kv_cache_static_quantization(model_id, target_size):
     shutil.rmtree(quantized_model_dir)
 
     # We expect the quantized model to be a certain size
+=======
+    # We expect the model to be < 160MB
+>>>>>>> 415c0b7 (Add fixed target sizes)
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size
 
-@pytest.mark.parametrize("model_id", MODELS)
-def test_kv_cache_static_quantization(model_id):
+@pytest.mark.parametrize("model_id,target_size", MODELS)
+def test_kv_cache_static_quantization(model_id, target_size):
     quantized_model_dir = model_id.split("/")[-1] + "-fp8-static-kv"
 
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
@@ -172,5 +191,5 @@ def test_kv_cache_static_quantization(model_id):
     shutil.rmtree(quantized_model_dir)
 
     # We expect the model to be < 160MB
-    target_size = 160 * (1024 * 1024)
+    target_size = target_size * (1024 * 1024)
     assert model_size < target_size
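
As committed, this change layers a second set of conflict markers on top of the ones already left in the file by 2739d61 (Add Qwen test), so tests/test_auto_fp8.py does not parse as Python after this commit. The sketch below is a minimal reading of the state the commit appears to be aiming for, taking the 415c0b7 side of each conflict (a 600MB budget for Qwen; the HEAD side carries 620). Only the MODELS list, the parametrize lines, and the size assertion come from this diff; the quantize-and-save steps in the middle are an assumption based on the flow in the AutoFP8 README (BaseQuantizeConfig, AutoFP8ForCausalLM.from_pretrained, quantize, save_quantized), not part of this change.

import os
import shutil

import pytest

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

# Each entry pairs a model with a size budget, in MB, for the quantized
# checkpoint on disk. The 415c0b7 side pins Qwen at 600; HEAD had 620.
MODELS = [
    ("facebook/opt-125m", 160),
    ("Qwen/Qwen2-0.5B-Instruct", 600),
]


@pytest.mark.parametrize("model_id,target_size", MODELS)
def test_dynamic_quantization(model_id, target_size):
    quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic"

    # Quantize and serialize; this middle section is assumed from the
    # AutoFP8 README, not from this diff.
    quantize_config = BaseQuantizeConfig(
        quant_method="fp8", activation_scheme="dynamic"
    )
    model = AutoFP8ForCausalLM.from_pretrained(
        model_id, quantize_config=quantize_config
    )
    model.quantize([])  # dynamic activation scales need no calibration data
    model.save_quantized(quantized_model_dir)

    model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
    shutil.rmtree(quantized_model_dir)

    # We expect the quantized model to stay under its fixed target size;
    # the budget is in MB, os.path.getsize returns bytes.
    target_size = target_size * (1024 * 1024)
    assert model_size < target_size

Storing (model_id, target_size) tuples lets @pytest.mark.parametrize unpack both values per case, so each model carries its own size budget instead of the hard-coded 160MB that the last hunk removes; the budget also shows up in the test id, e.g. test_dynamic_quantization[facebook/opt-125m-160].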
