
Commit 5831ba9

Add Qwen test
1 parent f934b0e · commit 5831ba9

File tree

2 files changed: +32 -4 lines changed


auto_fp8/modeling.py

Lines changed: 9 additions & 0 deletions
@@ -113,12 +113,15 @@ def skip(*args, **kwargs):
 
     def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
 <<<<<<< HEAD
+<<<<<<< HEAD
 =======
         def _prepare_calibration_data(calibration_tokens):
             if hasattr(calibration_tokens, "input_ids"):
                 return calibration_tokens.input_ids
             return calibration_tokens
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+>>>>>>> 2739d61 (Add Qwen test)
 
         # Always quantize the weights as they do not require calibration data
         quantize_weights(self.model, self.quantize_config)
@@ -128,15 +131,21 @@ def _prepare_calibration_data(calibration_tokens):
             calibration_tokens is not None
         ), "Calibration tokens required for activation quantization"
 <<<<<<< HEAD
+<<<<<<< HEAD
+=======
+>>>>>>> 2739d61 (Add Qwen test)
 
 
         def _prepare_calibration_data(calibration_tokens):
             if hasattr(calibration_tokens, "input_ids"):
                 return calibration_tokens.input_ids
             return calibration_tokens
 
+<<<<<<< HEAD
 =======
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+>>>>>>> 2739d61 (Add Qwen test)
         quantize_activations(
             self.model,
             self.quantize_config,
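
Note that both hunks stack a second layer of conflict markers (from merging 2739d61) on top of the 3ee9283 markers that were already committed, so modeling.py as committed is not importable: Python raises a SyntaxError at the first <<<<<<< HEAD. For reference, a minimal sketch of what the resolved method presumably looks like, keeping _prepare_calibration_data from the 3ee9283 branch; the activation_scheme guard and the final argument to quantize_activations are assumptions, since the hunk ends mid-call:

    def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
        def _prepare_calibration_data(calibration_tokens):
            # Tokenizers return a BatchEncoding; unwrap it to the raw
            # input_ids tensor, and pass plain tensors through unchanged.
            if hasattr(calibration_tokens, "input_ids"):
                return calibration_tokens.input_ids
            return calibration_tokens

        # Always quantize the weights as they do not require calibration data
        quantize_weights(self.model, self.quantize_config)

        # Assumed guard: only the static scheme needs calibration data.
        if self.quantize_config.activation_scheme == "static":
            assert (
                calibration_tokens is not None
            ), "Calibration tokens required for activation quantization"
            quantize_activations(
                self.model,
                self.quantize_config,
                _prepare_calibration_data(calibration_tokens),  # assumed argument
            )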

tests/test_auto_fp8.py

Lines changed: 23 additions & 4 deletions
@@ -1,16 +1,21 @@
 import os
 import shutil
 
+<<<<<<< HEAD
 <<<<<<< HEAD
 import pytest
 =======
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+import pytest
+>>>>>>> 2739d61 (Add Qwen test)
 import safetensors.torch
 from transformers import AutoTokenizer
 
 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
 
 MODELS = [
+<<<<<<< HEAD
     ("facebook/opt-125m", 160),
     ("Qwen/Qwen2-0.5B-Instruct", 620),
 ]
@@ -24,6 +29,15 @@ def test_dynamic_quantization():
     model_id = "facebook/opt-125m"
     quantized_model_dir = "opt-125m-fp8-dynamic"
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+    "facebook/opt-125m",
+    "Qwen/Qwen2-0.5B-Instruct",
+]
+
+@pytest.mark.parametrize("model_id", MODELS)
+def test_dynamic_quantization(model_id):
+    quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic"
+>>>>>>> 2739d61 (Add Qwen test)
 
     quantize_config = BaseQuantizeConfig(
         quant_method="fp8", activation_scheme="dynamic"
@@ -54,10 +68,16 @@ def test_static_quantization(model_id, target_size):
     assert model_size < target_size
 
 
+<<<<<<< HEAD
 def test_static_quantization():
     model_id = "facebook/opt-125m"
     quantized_model_dir = "opt-125m-fp8-static"
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+@pytest.mark.parametrize("model_id", MODELS)
+def test_static_quantization(model_id):
+    quantized_model_dir = model_id.split("/")[-1] + "-fp8-static"
+>>>>>>> 2739d61 (Add Qwen test)
 
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
     examples = ["auto-fp8 is an easy-to-use model quantization library"]
@@ -117,10 +137,9 @@ def test_kv_cache_static_quantization(model_id, target_size):
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size
 
-
-def test_kv_cache_static_quantization():
-    model_id = "facebook/opt-125m"
-    quantized_model_dir = "opt-125m-fp8-static-kv"
+@pytest.mark.parametrize("model_id", MODELS)
+def test_kv_cache_static_quantization(model_id):
+    quantized_model_dir = model_id.split("/")[-1] + "-fp8-static-kv"
 
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
     examples = ["auto-fp8 is an easy-to-use model quantization library"]
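
The 2739d61 side of these conflicts is recoverable from the added lines alone: MODELS collapses to a flat list of model ids, each test takes model_id via @pytest.mark.parametrize, and the output directory is derived from the last path component of the model name. A minimal sketch of the resolved dynamic test under that reading; the from_pretrained/quantize/save_quantized calls and the cleanup are assumptions carried over from the unchanged parts of the test file, which this diff does not show:

    import shutil

    import pytest

    from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

    MODELS = [
        "facebook/opt-125m",
        "Qwen/Qwen2-0.5B-Instruct",
    ]


    @pytest.mark.parametrize("model_id", MODELS)
    def test_dynamic_quantization(model_id):
        # "Qwen/Qwen2-0.5B-Instruct" -> "Qwen2-0.5B-Instruct-fp8-dynamic"
        quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic"

        quantize_config = BaseQuantizeConfig(
            quant_method="fp8", activation_scheme="dynamic"
        )
        model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)
        model.quantize()  # dynamic scheme needs no calibration tokens
        model.save_quantized(quantized_model_dir)
        shutil.rmtree(quantized_model_dir)

Because each parametrized case gets its own test id, a single model can be selected by substring, e.g. pytest tests/test_auto_fp8.py -k Qwen2.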
