Commit e599d43

unit test
Signed-off-by: weimingc <[email protected]>
1 parent d1c5d19 commit e599d43

2 files changed: +102 / -1 lines changed

examples/vllm_serve/vllm_serve_fakequant.py

Lines changed: 3 additions & 1 deletion

@@ -97,7 +97,8 @@ def disable_compilation(model):
 quant_config: dict[str, Any] = {
     "quant_dataset": "cnn_dailymail",
     "quant_num_samples": 512,
-    "quant_format": "NVFP4_DEFAULT_CFG",
+    # "quant_format": "NVFP4_DEFAULT_CFG",
+    "quant_format": "NVFP4_AWQ_LITE_CFG",
     "amax_file_path": None,  # Optional: path to pre-computed amax values (e.g., "/path/to/amax.pt")
 }

@@ -176,6 +177,7 @@ def calibrate_loop(model: Any = None) -> None:

         quant_cfg = getattr(mtq, quant_config["quant_format"])

+        print(f"Quantizing model with {quant_config['quant_format']} format")
         with disable_compilation(self.model):
             mtq.quantize(self.model, quant_cfg, forward_loop=calibrate_loop)
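
Note: the change above only swaps which ModelOpt config name the example resolves at runtime via getattr(mtq, ...). Below is a minimal standalone sketch of that resolve-and-quantize flow, assuming a CUDA device, an illustrative Hugging Face checkpoint, and a one-sample calibration loop in place of the example's cnn_dailymail-driven one:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.quantization as mtq

# Placeholder checkpoint and calibration text, for illustration only.
model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

quant_format = "NVFP4_AWQ_LITE_CFG"  # or "NVFP4_DEFAULT_CFG", as before this change
quant_cfg = getattr(mtq, quant_format)


def calibrate_loop(m=None):
    # Stand-in for the example's dataset-driven calibration loop.
    m = m if m is not None else model
    inputs = tokenizer("a short calibration sample", return_tensors="pt").to("cuda")
    with torch.no_grad():
        m(**inputs)


print(f"Quantizing model with {quant_format} format")
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)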

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+
+pytest.importorskip("transformers")
+
+from transformers import LlamaConfig, LlamaForCausalLM
+
+import modelopt.torch.quantization as mtq
+from modelopt.torch.export.quant_utils import pattern_fuse_prequant
+
+
+def get_tiny_llama(attention_heads=4, key_value_heads=4):
+    """Create a tiny Llama model for testing."""
+    config = LlamaConfig(
+        hidden_size=64,
+        intermediate_size=128,
+        num_hidden_layers=2,
+        num_attention_heads=attention_heads,
+        num_key_value_heads=key_value_heads,
+        max_position_embeddings=128,
+        vocab_size=256,
+    )
+    return LlamaForCausalLM(config)
+
+
+@pytest.mark.parametrize(
+    "quant_config",
+    [
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_AWQ_LITE_CFG,
+    ],
+)
+@pytest.mark.parametrize(
+    "attention_kv_heads_pair",
+    [
+        (4, 4),  # MHA
+        (4, 2),  # GQA
+        (4, 1),  # MQA
+    ],
+)
+def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair):
+    """Test pattern_fuse_prequant on modules from a tiny Llama model."""
+    model = get_tiny_llama(attention_kv_heads_pair[0], attention_kv_heads_pair[1]).to("cuda")
+
+    # Quantize the model
+    dummy_input = torch.randint(0, 256, (1, 16), device="cuda")
+    mtq.quantize(model, quant_config, lambda m: m(dummy_input))
+
+    # Run a forward pass before fusion
+    model.eval()
+    with torch.no_grad():
+        output_before_fuse = model(dummy_input)
+
+    target_module_name_list = [
+        "model.layers.0.self_attn.o_proj",
+        "model.layers.0.mlp.down_proj",
+        "model.layers.1.self_attn.o_proj",
+        "model.layers.1.mlp.down_proj",
+    ]
+
+    # Apply fusion
+    pattern_fuse_prequant(model)
+
+    # Check that pre_quant_scale was removed and the fused_with_prequant flag was set
+    for target_module_name in target_module_name_list:
+        target_module = model.get_submodule(target_module_name)
+
+        # Verify pre_quant_scale was removed
+        assert not hasattr(target_module.input_quantizer, "_pre_quant_scale"), (
+            f"{target_module_name}: pre_quant_scale should be removed after fusion"
+        )
+
+        # Verify fused_with_prequant flag was set
+        assert (
+            hasattr(target_module, "fused_with_prequant") and target_module.fused_with_prequant
+        ), f"{target_module_name}: fused_with_prequant flag should be set"
+
+    # Verify the output stays close to the original output
+    with torch.no_grad():
+        output_after_fuse = model(dummy_input)
+    # Small differences are expected from quantization error after fusing pre_quant_scale into the weights
+    assert torch.allclose(
+        output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1
+    ), "Outputs should match closely before and after fusion"
