
Commit efcf946 (parent 1230263)

[Hardware][NV] Add support for ModelOpt static scaling checkpoints. (vllm-project#6112)

File tree: 7 files changed, +258 −6 lines
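For context, the usage pattern this commit enables is the one exercised by the new test below: load a ModelOpt static-FP8 checkpoint by passing quantization="modelopt" to vLLM. A minimal sketch, assuming an FP8-capable GPU and reusing the checkpoint name from the test file (the prompt and output handling here are illustrative):

# Minimal usage sketch; mirrors tests/models/test_modelopt.py below.
from vllm import LLM, SamplingParams

llm = LLM(
    model="nvidia/Llama-3.1-8B-Instruct-FP8",  # ModelOpt static-scaling FP8 checkpoint
    quantization="modelopt",                   # quantization method added by this commit
    max_model_len=1024,
)
params = SamplingParams(max_tokens=20, temperature=0)
outputs = llm.generate(["What is vLLM?"], params)
print(outputs[0].outputs[0].text)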

examples/fp8/quantizer/README.md

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 ### Quantizer Utilities
-`quantize.py`: NVIDIA Quantization utilities using AMMO, ported from TensorRT-LLM:
-`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py`
+`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported
+from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py)
 
 ### Prerequisite
 

tests/models/test_modelopt.py

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@ (new file)
# flake8: noqa
"""Tests Model Optimizer fp8 models against ground truth generation
Note: these tests will only pass on H100
"""
import os
from typing import List

import pytest
from transformers import AutoTokenizer

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024

MODELS = ["nvidia/Llama-3.1-8B-Instruct-FP8"]

EXPECTED_STRS_MAP = {
    "nvidia/Llama-3.1-8B-Instruct-FP8": [
        "You're referring to VLLM, a high-performance Large Language Model (LLM) inference and",
        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
        'The comparison between artificial intelligence (AI) and human intelligence in terms of processing information is a complex and',
        'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
        '**The Spark of Imagination**\n\nZeta-5, a sleek and efficient robot, whir',
        'The COVID-19 pandemic has had a profound impact on global economic structures and business models, leading to',
        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
        'Here are the translations:\n\n**Japanese:** 「早起きは早く獲物をとる'
    ]
}


# This test compares against golden strings for exact match since
# there is no baseline implementation to compare against
# and is unstable w.r.t specifics of the fp8 implementation or
# the hardware being run on.
# Disabled to prevent it from breaking the build
@pytest.mark.skip(
    reason=
    "Prevent unstable test based on golden strings from breaking the build.")
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
    model = LLM(
        model=model_name,
        max_model_len=MAX_MODEL_LEN,
        trust_remote_code=True,
        enforce_eager=True,
        quantization="modelopt",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    formatted_prompts = [
        tokenizer.apply_chat_template([{
            "role": "user",
            "content": prompt
        }],
                                      tokenize=False,
                                      add_generation_prompt=True)
        for prompt in example_prompts
    ]
    params = SamplingParams(max_tokens=20, temperature=0)
    generations: List[str] = []
    # Note: these need to be run 1 at a time due to numerical precision,
    # since the expected strs were generated this way.
    for prompt in formatted_prompts:
        outputs = model.generate(prompt, params)
        generations.append(outputs[0].outputs[0].text)
    del model

    print(model_name, generations)
    expected_strs = EXPECTED_STRS_MAP[model_name]
    for i in range(len(example_prompts)):
        generated_str = generations[i]
        expected_str = expected_strs[i]
        assert expected_str == generated_str, (
            f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")

vllm/config.py

Lines changed: 3 additions & 3 deletions

@@ -282,9 +282,9 @@ def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = ["awq", "gptq", "fp8"]
         optimized_quantization_methods = [
-            "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
-            "fbgemm_fp8", "compressed_tensors", "compressed-tensors",
-            "experts_int8"
+            "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
+            "awq_marlin", "fbgemm_fp8", "compressed_tensors",
+            "compressed-tensors", "experts_int8"
         ]
         tpu_supported_quantization = ["tpu_int8"]
         neuron_supported_quantization = ["neuron_quant"]

vllm/model_executor/layers/linear.py

Lines changed: 2 additions & 1 deletion

@@ -26,7 +26,8 @@
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
     "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
-    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod"
+    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
+    "ModelOptFp8LinearMethod"
 ]
 
 
vllm/model_executor/layers/quantization/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,7 @@
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQMarlin24Config)
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
 from vllm.model_executor.layers.quantization.neuron_quant import (
     NeuronQuantConfig)
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
@@ -34,6 +35,7 @@
     "tpu_int8": Int8TpuConfig,
     "fp8": Fp8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
+    "modelopt": ModelOptFp8Config,
     # The order of gptq methods is important for config.py iteration over
     # override_quantization_method(..)
     "marlin": MarlinConfig,
vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 163 additions & 0 deletions

@@ -0,0 +1,163 @@ (new file)
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
                                           PerTensorScaleParameter)

logger = init_logger(__name__)

ACTIVATION_SCHEMES = ["static"]


class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                           " the format is experimental and could change.")

    @classmethod
    def get_name(cls) -> str:
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        is_checkpoint_fp8_serialized = ("FP8" in quant_method)
        if not is_checkpoint_fp8_serialized:
            raise ValueError("ModelOpt currently only supports static FP8"
                             "quantization in vLLM. Please check the "
                             "`hf_quant_config.json` file for your model's "
                             "quant configuration.")
        return cls(is_checkpoint_fp8_serialized)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        super().__init__(quant_config)


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype
    Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            weight_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                   weight_loader=weight_loader)
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE
            scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                            weight_loader=weight_loader)

            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        max_w_scale, weight = requantize_with_max_scale(
            layer.weight, layer.weight_scale, layer.logical_widths)
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(),
                                      requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported)
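The scheme above is per-tensor static FP8: one float8_e4m3fn weight tensor plus a single float32 scale each for weight and activation, after requantize_with_max_scale has collapsed the per-shard weight scales into one maximum. apply_fp8_linear then runs the scaled matmul on the fused CUTLASS / torch._scaled_mm path. As a conceptual reference only, not the kernel vLLM actually executes, the arithmetic amounts to roughly the following sketch (all names and shapes here are illustrative):

from typing import Optional

import torch


def fp8_linear_reference(x: torch.Tensor,
                         weight_fp8: torch.Tensor,    # (out_features, in_features), float8_e4m3fn
                         weight_scale: torch.Tensor,  # scalar float32, per-tensor
                         input_scale: torch.Tensor,   # scalar float32, per-tensor (static)
                         bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Quantize the activation with the static input scale, then dequantize both
    # operands and matmul in bf16. Reference arithmetic only; the real path
    # keeps operands in FP8 and applies the scales inside the fused GEMM.
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_fp8 = (x / input_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    x_dq = x_fp8.to(torch.bfloat16) * input_scale
    w_dq = weight_fp8.to(torch.bfloat16) * weight_scale
    out = x_dq @ w_dq.t()
    return out if bias is None else out + bias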

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 7 additions & 0 deletions

@@ -192,6 +192,13 @@ def get_quant_config(model_config: ModelConfig,
 
     if model_config.quantization == "bitsandbytes":
         config["adapter_name_or_path"] = model_name_or_path
+    elif model_config.quantization == "modelopt":
+        if config["producer"]["name"] == "modelopt":
+            return quant_cls.from_config(config)
+        else:
+            raise ValueError(
+                f"Unsupported quantization config"
+                f" found for {model_config.quantization} in {f}.")
 
     return quant_cls.from_config(config)
 
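Taken together with ModelOptFp8Config.from_config above, this hunk pins down the minimal shape the loader expects from the checkpoint's hf_quant_config.json: a producer named "modelopt" and a quant_algo containing "FP8". A hypothetical minimal example, limited to the keys this commit actually reads (real ModelOpt exports carry additional fields):

# Hypothetical minimal dict that satisfies the two checks added in this commit.
hf_quant_config = {
    "producer": {"name": "modelopt"},       # checked in weight_utils.get_quant_config
    "quantization": {"quant_algo": "FP8"},  # must contain "FP8" for from_config
}
# ModelOptFp8Config.from_config(hf_quant_config) would return a config with
# is_checkpoint_fp8_serialized=True; any quant_algo without "FP8" raises ValueError.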