
Commit 852e506

[https://nvbugs/5558117][fix] Allow per-layer quant config from hf_quant_config.json (#8617)
Signed-off-by: Anthony Chang <[email protected]>
1 parent 98453d2 commit 852e506
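
For reference, the per-layer layout this change accepts looks like the following (shape taken from the new unit tests below; the layer names and algorithms are purely illustrative). The same structure can live either at the top level of a standalone quant_cfg.json or nested under the "quantization" key of hf_quant_config.json:

{
    "quant_algo": "MIXED_PRECISION",
    "kv_cache_quant_algo": "FP8",
    "quantized_layers": {
        "model.layers.0.self_attn.q_proj": {"quant_algo": "FP8"},
        "model.layers.1.mlp.gate_proj": {"quant_algo": "W4A8_AWQ", "group_size": 128}
    }
}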

File tree

3 files changed: +147 -25 lines changed

tensorrt_llm/_torch/model_config.py

Lines changed: 40 additions & 24 deletions

@@ -113,9 +113,9 @@ class ModelConfig(Generic[TConfig]):
     pretrained_config: Optional[TConfig] = None
     mapping: Mapping = field(default_factory=Mapping)

-    # quantization configs
+    # Quantization configs
     quant_config: QuantConfig = field(default_factory=QuantConfig)
-    # TODO(qijun): support per linear layer quantization
+    # Per linear layer quantization in quant_cfg.json or hf_quant_config.json
     quant_config_dict: Optional[Dict[str, QuantConfig]] = None
     # Delay weights creation to DecoderModelForCausalLM.__post_init__
     # to support mixed quantization.
@@ -278,28 +278,41 @@ def load_modelopt_quant_config(quant_config_file, checkpoint_dir,
             'exclude_modules', None)

         if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
-            mixed_quant_config_file = transformers.utils.hub.cached_file(
-                checkpoint_dir, 'quant_cfg.json')
-            with open(mixed_quant_config_file) as fm:
-                mixed_quant_configs = json.load(fm)
-            # kv_cache_quant_algo is global regardless of MIXED_PRECISION
-            kv_cache_quant_algo = mixed_quant_configs['kv_cache_quant_algo']
-            mixed_quant_configs = mixed_quant_configs['quantized_layers']
-            if kv_cache_quant_algo is not None and quant_config.kv_cache_quant_algo is not None:
-                if kv_cache_quant_algo != quant_config.kv_cache_quant_algo:
-                    raise RuntimeError(
-                        f"The kvcache config in 'quant_cfg.json', {kv_cache_quant_algo},"
-                        f"is different from 'hf_quant_config.json', {quant_config.kv_cache_quant_algo}!"
-                    )
-            kv_cache_quant_algo = kv_cache_quant_algo or quant_config.kv_cache_quant_algo
-
-            for layer in mixed_quant_configs:
-                config = QuantConfig()
-                config.kv_cache_quant_algo = kv_cache_quant_algo
-                config.quant_algo = mixed_quant_configs[layer]['quant_algo']
-                config.group_size = mixed_quant_configs[layer].get(
-                    'group_size', None)
-                mixed_quant_configs[layer] = config
+            json_extended_quant_configs: dict = {}
+            # See tests/unittest/llmapi/test_llm_quant.py
+            try:
+                mixed_quant_config_file = transformers.utils.hub.cached_file(
+                    checkpoint_dir, 'quant_cfg.json')
+                with open(mixed_quant_config_file) as fm:
+                    json_extended_quant_configs = json.load(fm)
+            except Exception:
+                logger.info(
+                    f"No quant_cfg.json found for layer quant info, using hf_quant_config.json."
+                )
+            json_quant_configs.update(json_extended_quant_configs)
+            # kv_cache_quant_algo is global regardless of MIXED_PRECISION
+            kv_cache_quant_algo = json_quant_configs.get(
+                'kv_cache_quant_algo', None)
+            mixed_quant_configs = json_quant_configs.get(
+                'quantized_layers', None)
+            if (kv_quant_lhs := json_extended_quant_configs.get(
+                    "kv_cache_quant_algo", None)) is not None and (
+                        kv_quant_rhs :=
+                        quant_config.kv_cache_quant_algo) is not None:
+                if kv_quant_lhs != kv_quant_rhs:
+                    raise RuntimeError(
+                        f"The kvcache config in 'quant_cfg.json', {kv_quant_lhs},"
+                        f"is different from 'hf_quant_config.json', {kv_quant_rhs}!"
+                    )
+            quant_config.kv_cache_quant_algo = json_quant_configs[
+                "kv_cache_quant_algo"]
+            for layer in mixed_quant_configs:
+                config = QuantConfig()
+                config.kv_cache_quant_algo = kv_cache_quant_algo
+                config.quant_algo = mixed_quant_configs[layer]['quant_algo']
+                config.group_size = mixed_quant_configs[layer].get(
+                    'group_size', None)
+                mixed_quant_configs[layer] = config
             layer_quant_config = mixed_quant_configs
         elif quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
             if quant_config.group_size is None:
@@ -459,6 +472,9 @@ def cached_file(path_or_repo_id, file_name):
             except OSError:
                 return None

+        # Some checkpoints lack torch_dtype, populate with dtype
+        pretrained_config.torch_dtype = getattr(pretrained_config, 'dtype',
+                                                None)
         quant_config = QuantConfig()
         layer_quant_config = None
         moe_backend = kwargs.get('moe_backend', 'CUTLASS')
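
As a reading aid, below is a minimal standalone sketch of the precedence logic the MIXED_PRECISION hunk above implements: per-layer entries come from quant_cfg.json when that file exists and otherwise fall back to the "quantization" section of hf_quant_config.json. The helper name and plain-dict return type are illustrative only, not part of the library API.

import json
from pathlib import Path


def resolve_layer_quant(checkpoint_dir: str) -> dict:
    """Per-layer quant info: prefer quant_cfg.json, else hf_quant_config.json."""
    ckpt = Path(checkpoint_dir)
    cfg = json.loads((ckpt / "hf_quant_config.json").read_text())["quantization"]

    # quant_cfg.json, when present, extends/overrides the hf_quant_config.json entries.
    extra = ckpt / "quant_cfg.json"
    if extra.exists():
        cfg.update(json.loads(extra.read_text()))

    # kv_cache_quant_algo stays global and is copied onto every layer entry.
    kv_cache_quant_algo = cfg.get("kv_cache_quant_algo")
    return {
        name: {
            "quant_algo": spec["quant_algo"],
            "group_size": spec.get("group_size"),
            "kv_cache_quant_algo": kv_cache_quant_algo,
        }
        for name, spec in cfg.get("quantized_layers", {}).items()
    }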

tensorrt_llm/llmapi/llm_utils.py

Lines changed: 3 additions & 1 deletion

@@ -390,7 +390,9 @@ def _update_from_hf_quant_config(self) -> bool:
             )

         for key, value in hf_quant_config.items():
-            logger.info(f"Setting {key}={value} from HF quant config.")
+            logger.info(
+                f"Setting {key}={str(value)[:100]}{'...' if len(str(value)) > 100 else ''} from HF quant config."
+            )
             setattr(quant_config, key, value)

         # Update the quant_config in llm_args for pytorch

tests/unittest/llmapi/test_llm_quant.py

Lines changed: 104 additions & 0 deletions

@@ -1,6 +1,11 @@
+import json
+import tempfile
+from pathlib import Path
+
 import pytest

 from tensorrt_llm._tensorrt_engine import LLM
+from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
 from tensorrt_llm.llmapi.llm_utils import CalibConfig, QuantAlgo, QuantConfig

@@ -71,6 +76,105 @@ def test_llm_fp8_quantization_modelOpt_ckpt():
     assert output.outputs[0].text == " D E F G H I"


+def test_quant_cfg_from_quant_cfg_json():
+    """
+    Test loading MIXED_PRECISION config from quant_cfg.json with per-layer quantization.
+    This supports the workflow from examples/quantization/quantize_mixed_precision_moe.py.
+    """
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        model_dir = Path(tmp_dir)
+
+        # Create dummy quant_cfg.json
+        quant_cfg_content = {
+            "quant_algo": "MIXED_PRECISION",
+            "kv_cache_quant_algo": "FP8",
+            "quantized_layers": {
+                "model.layers.0.self_attn.q_proj": {
+                    "quant_algo": "FP8"
+                },
+                "model.layers.0.self_attn.k_proj": {
+                    "quant_algo": "FP8"
+                },
+                "model.layers.1.mlp.gate_proj": {
+                    "quant_algo": "W4A8_AWQ",
+                    "group_size": 128
+                }
+            }
+        }
+
+        quant_cfg_file = model_dir / "quant_cfg.json"
+        with open(quant_cfg_file, 'w') as f:
+            json.dump(quant_cfg_content, f)
+
+        # Create dummy hf_quant_config.json
+        hf_quant_config_content = {
+            "quantization": {
+                "quant_algo": "MIXED_PRECISION",
+                "kv_cache_quant_algo": None,
+            }
+        }
+
+        hf_quant_config_file = model_dir / "hf_quant_config.json"
+        with open(hf_quant_config_file, 'w') as f:
+            json.dump(hf_quant_config_content, f)
+
+        quant_config, layer_quant_config = ModelConfig.load_modelopt_quant_config(
+            hf_quant_config_file, model_dir, None)
+
+        # Verify quant_cfg.json was loaded
+        assert quant_config.quant_algo == QuantAlgo.MIXED_PRECISION
+        assert quant_config.kv_cache_quant_algo == "FP8"
+
+        # Verify layer configs were created correctly
+        assert layer_quant_config[
+            "model.layers.0.self_attn.q_proj"].quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.0.self_attn.q_proj"].kv_cache_quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.1.mlp.gate_proj"].quant_algo == "W4A8_AWQ"
+        assert layer_quant_config[
+            "model.layers.1.mlp.gate_proj"].group_size == 128
+
+
+def test_quant_cfg_from_hf_quant_config():
+    """Test fallback to hf_quant_config.json when quant_cfg.json is missing."""
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        model_dir = Path(tmp_dir)
+
+        # Create dummy hf_quant_config.json
+        hf_quant_config_content = {
+            "quantization": {
+                "quant_algo": "MIXED_PRECISION",
+                "kv_cache_quant_algo": "FP8",
+                "quantized_layers": {
+                    "model.layers.0.self_attn.q_proj": {
+                        "quant_algo": "FP8"
+                    },
+                    "model.layers.0.mlp.up_proj": {
+                        "quant_algo": "W4A16_AWQ",
+                        "group_size": 64
+                    }
+                }
+            }
+        }
+        hf_quant_config_file = model_dir / "hf_quant_config.json"
+        with open(hf_quant_config_file, 'w') as f:
+            json.dump(hf_quant_config_content, f)
+        quant_config, layer_quant_config = ModelConfig.load_modelopt_quant_config(
+            hf_quant_config_file, model_dir, None)
+
+        # Verify layer configs
+        assert quant_config.quant_algo == QuantAlgo.MIXED_PRECISION
+        assert quant_config.kv_cache_quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.0.self_attn.q_proj"].quant_algo == "FP8"
+        assert layer_quant_config[
+            "model.layers.0.mlp.up_proj"].quant_algo == "W4A16_AWQ"
+        assert layer_quant_config["model.layers.0.mlp.up_proj"].group_size == 64
+
+
 if __name__ == "__main__":
     test_llm_int4_awq_quantization()
     test_llm_fp8_quantization_modelOpt_ckpt()
+    test_quant_cfg_from_quant_cfg_json()
+    test_quant_cfg_from_hf_quant_config()
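
Assuming a standard pytest setup, both new cases can be selected by their shared name prefix with

    pytest tests/unittest/llmapi/test_llm_quant.py -k quant_cfg_from

since they only create temporary JSON files and call ModelConfig.load_modelopt_quant_config rather than building an engine.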
