Commit 2739d61

Add Qwen test
1 parent def2049 commit 2739d61

2 files changed: +21 -18 lines changed

2 files changed

+21
-18
lines changed

auto_fp8/modeling.py

Lines changed: 7 additions & 8 deletions
@@ -108,10 +108,6 @@ def skip(*args, **kwargs):
         return cls(model, quantize_config)
 
     def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
-        def _prepare_calibration_data(calibration_tokens):
-            if hasattr(calibration_tokens, "input_ids"):
-                return calibration_tokens.input_ids
-            return calibration_tokens
 
         # Always quantize the weights as they do not require calibration data
         quantize_weights(self.model, self.quantize_config)
@@ -120,16 +116,19 @@ def _prepare_calibration_data(calibration_tokens):
         assert (
             calibration_tokens is not None
         ), "Calibration tokens required for activation quantization"
+
+
+        def _prepare_calibration_data(calibration_tokens):
+            if hasattr(calibration_tokens, "input_ids"):
+                return calibration_tokens.input_ids
+            return calibration_tokens
+
         quantize_activations(
             self.model,
             self.quantize_config,
             _prepare_calibration_data(calibration_tokens),
         )
 
-        # import copy
-        # for layer in self.model.model.layers:
-        #     layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.input_scale)
-
     def save_quantized(self, save_dir):
         save_quantized_model(
             self.model,
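
The relocated `_prepare_calibration_data` helper only normalizes the calibration input: a tokenizer's BatchEncoding carries the token ids under `input_ids`, while a plain tensor is passed through untouched. A minimal sketch of that behavior (the tokenizer call and model id below are illustrative, not part of this commit):

import torch
from transformers import AutoTokenizer

def _prepare_calibration_data(calibration_tokens):
    # BatchEncoding objects from a tokenizer expose the token ids as
    # .input_ids; bare tensors are returned unchanged.
    if hasattr(calibration_tokens, "input_ids"):
        return calibration_tokens.input_ids
    return calibration_tokens

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", use_fast=True)
batch = tokenizer(
    ["auto-fp8 is an easy-to-use model quantization library"],
    return_tensors="pt",
)
assert torch.equal(_prepare_calibration_data(batch), batch.input_ids)  # BatchEncoding -> tensor
raw = torch.tensor([[1, 2, 3]])
assert _prepare_calibration_data(raw) is raw  # raw tensor passes through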

tests/test_auto_fp8.py

Lines changed: 14 additions & 10 deletions
@@ -1,15 +1,20 @@
 import os
 import shutil
 
+import pytest
 import safetensors.torch
 from transformers import AutoTokenizer
 
 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
 
+MODELS = [
+    "facebook/opt-125m",
+    "Qwen/Qwen2-0.5B-Instruct",
+]
 
-def test_dynamic_quantization():
-    model_id = "facebook/opt-125m"
-    quantized_model_dir = "opt-125m-fp8-dynamic"
+@pytest.mark.parametrize("model_id", MODELS)
+def test_dynamic_quantization(model_id):
+    quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic"
 
     quantize_config = BaseQuantizeConfig(
         quant_method="fp8", activation_scheme="dynamic"
@@ -30,9 +35,9 @@ def test_dynamic_quantization():
     assert model_size < target_size
 
 
-def test_static_quantization():
-    model_id = "facebook/opt-125m"
-    quantized_model_dir = "opt-125m-fp8-static"
+@pytest.mark.parametrize("model_id", MODELS)
+def test_static_quantization(model_id):
+    quantized_model_dir = model_id.split("/")[-1] + "-fp8-static"
 
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
     examples = ["auto-fp8 is an easy-to-use model quantization library"]
@@ -54,10 +59,9 @@ def test_static_quantization():
     target_size = 160 * (1024 * 1024)
     assert model_size < target_size
 
-
-def test_kv_cache_static_quantization():
-    model_id = "facebook/opt-125m"
-    quantized_model_dir = "opt-125m-fp8-static-kv"
+@pytest.mark.parametrize("model_id", MODELS)
+def test_kv_cache_static_quantization(model_id):
+    quantized_model_dir = model_id.split("/")[-1] + "-fp8-static-kv"
 
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
     examples = ["auto-fp8 is an easy-to-use model quantization library"]
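
With the parametrization in place, each test runs once per entry in MODELS, covering both the OPT and Qwen2 architectures. The flow the static test exercises looks roughly like the sketch below, pieced together from the imports and methods visible in this diff; arguments not shown in the diff (e.g. return_tensors, the quantize_config keyword) are assumptions:

import shutil

from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

model_id = "Qwen/Qwen2-0.5B-Instruct"
quantized_model_dir = model_id.split("/")[-1] + "-fp8-static"

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
examples = ["auto-fp8 is an easy-to-use model quantization library"]
# BatchEncoding; unwrapped to input_ids by _prepare_calibration_data
calibration_tokens = tokenizer(examples, return_tensors="pt")

quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config=quantize_config)
model.quantize(calibration_tokens)  # weights always quantized; static activations need the tokens
model.save_quantized(quantized_model_dir)

shutil.rmtree(quantized_model_dir)  # the tests remove the artifact when done

Parametrizing over MODELS rather than duplicating each test keeps the per-architecture coverage in one place; adding another model to the suite is a one-line change.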
