
Commit 955b519

[Misc] update fp8 to use vLLMParameter (vllm-project#7437)
1 parent 55d63b1 commit 955b519

4 files changed: 51 additions & 17 deletions

tests/weight_loading/models.txt

Lines changed: 1 addition & 0 deletions
@@ -15,3 +15,4 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
+fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
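
Each row of this file appears to follow the pattern "<quantization>, <model id>, <revision>". A minimal parse sketch of the new entry (illustrative only, not the actual test harness):

# Illustrative parse of one row; the real weight-loading test may differ.
row = "fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main"
quantization, model_id, revision = (field.strip() for field in row.split(","))
assert quantization == "fp8"
assert revision == "main"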

vllm/model_executor/layers/linear.py

Lines changed: 13 additions & 1 deletion
@@ -22,7 +22,7 @@
 
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
-    "AWQLinearMethod", "GPTQMarlinLinearMethod"
+    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
 ]
 
 
@@ -349,6 +349,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         param_data.copy_(loaded_weight)
 
     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
@@ -1021,6 +1026,13 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
 
     def weight_loader_v2(self, param: BasevLLMParameter,
                          loaded_weight: torch.Tensor):
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
+
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
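
The new guard exists because per-tensor scales (for example from AutoFP8 checkpoints) are often stored as 0-dim tensors. A standalone sketch of the same reshape logic, using made-up names and a made-up scale value:

import torch

# A per-tensor scale saved as a scalar has an empty shape.
scale_on_disk = torch.tensor(0.0123)           # shape == torch.Size([])

# Same check as weight_loader_v2: promote a 0-dim scale to shape (1,)
# so it can be copied into a registered 1-D scale parameter.
if len(scale_on_disk.shape) == 0:
    assert scale_on_disk.numel() == 1
    scale_on_disk = scale_on_disk.reshape(1)

weight_scale = torch.zeros(1, dtype=torch.float32)
weight_scale.copy_(scale_on_disk)
print(weight_scale)  # tensor([0.0123])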

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 29 additions & 15 deletions
@@ -19,9 +19,10 @@
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    create_per_tensor_scale_param, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
+from vllm.model_executor.parameter import (ModelWeightParameter,
+                                           PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import is_hip, print_warning_once
@@ -137,6 +138,7 @@ def create_weights(
     ):
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         layer.logical_widths = output_partition_sizes
 
@@ -148,34 +150,41 @@ def create_weights(
         weight_dtype = (torch.float8_e4m3fn
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)
-        weight = Parameter(torch.empty(output_size_per_partition,
-                                       input_size_per_partition,
-                                       dtype=weight_dtype),
-                           requires_grad=False)
+
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=weight_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
 
         # If checkpoint is serialized fp8, load them.
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                  **extra_weight_attrs)
+            scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=torch.float32),
+                                            weight_loader=weight_loader)
+
+            scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
-                scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                      **extra_weight_attrs)
+                scale = PerTensorScaleParameter(data=torch.empty(
+                    len(output_partition_sizes), dtype=torch.float32),
+                                                weight_loader=weight_loader)
+
+                scale[:] = torch.finfo(torch.float32).min
                 layer.register_parameter("input_scale", scale)
             else:
                 layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
@@ -197,6 +206,11 @@ def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
+            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                    requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+                                                       requires_grad=False)
             # If using marlin (w8a16), kernel uses channelwise weights,
             # so extend the weight scales to be channelwise.
             if self.use_marlin:
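
A rough sketch of how Fp8LinearMethod.create_weights now wires up the vLLM parameter classes. The constructor arguments mirror the diff above; the bare Module, the dummy loader, and the shapes below are made up for illustration:

import torch
from torch.nn import Module

from vllm.model_executor.parameter import (ModelWeightParameter,
                                           PerTensorScaleParameter)

def dummy_weight_loader(param, loaded_weight, *args, **kwargs):
    # Placeholder only: real loaders route shards via weight_loader_v2.
    pass

layer = Module()
output_partition_sizes = [2048, 2048]   # hypothetical fused shard widths
input_size_per_partition = 4096         # hypothetical input width

weight = ModelWeightParameter(data=torch.empty(
    sum(output_partition_sizes),
    input_size_per_partition,
    dtype=torch.float8_e4m3fn),
                              input_dim=1,
                              output_dim=0,
                              weight_loader=dummy_weight_loader)
layer.register_parameter("weight", weight)

# One scale slot per logical shard, set to a sentinel until loaded.
weight_scale = PerTensorScaleParameter(data=torch.empty(
    len(output_partition_sizes), dtype=torch.float32),
                                       weight_loader=dummy_weight_loader)
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)

# After loading, process_weights_after_loading swaps these back to plain
# torch.nn.Parameter objects so the forward path sees ordinary tensors.
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                        requires_grad=False)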

vllm/model_executor/parameter.py

Lines changed: 8 additions & 1 deletion
@@ -208,18 +208,25 @@ def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
         if isinstance(shard_id, int):
             return shard_id
 
+        # if not int, assume shard_id for qkv
+        # map to int and return
         assert isinstance(shard_id, str)
         assert shard_id in self.qkv_idxs
         return self.qkv_idxs[shard_id]
 
+    # For row parallel layers, no sharding needed
+    # load weight into parameter as is
+    def load_row_parallel_weight(self, *args, **kwargs):
+        super().load_row_parallel_weight(*args, **kwargs)
+
     def load_merged_column_weight(self, *args, **kwargs):
         self._load_into_shard_id(*args, **kwargs)
 
     def load_qkv_weight(self, *args, **kwargs):
         self._load_into_shard_id(*args, **kwargs)
 
     def load_column_parallel_weight(self, *args, **kwargs):
-        self._load_into_shard_id(*args, **kwargs)
+        super().load_row_parallel_weight(*args, **kwargs)
 
     def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                             shard_id: Union[str, int], **kwargs):
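
The PerTensorScaleParameter changes above hinge on mapping string qkv shard ids to integer slots and writing each shard's scale into its slot of a length-N scale tensor. A toy standalone illustration of that bookkeeping (not vLLM code; the scale values are invented):

import torch

qkv_idxs = {"q": 0, "k": 1, "v": 2}

def shard_id_as_int(shard_id):
    if isinstance(shard_id, int):
        return shard_id
    # if not int, assume a qkv shard id and map it to an int
    assert shard_id in qkv_idxs
    return qkv_idxs[shard_id]

# One slot per fused shard, initialized to a sentinel until loaded.
scales = torch.full((3,), torch.finfo(torch.float32).min)
for shard_id, loaded_scale in [("q", 0.01), ("k", 0.02), ("v", 0.03)]:
    scales[shard_id_as_int(shard_id)] = loaded_scale

print(scales)  # tensor([0.0100, 0.0200, 0.0300])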
