Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit e16fa99

Browse files
dsikka and mgoin authored
[Misc] Update fbgemmfp8 to use vLLMParameters (vllm-project#7972)
Co-authored-by: Michael Goin <[email protected]>
1 parent 61f4a93 commit e16fa99

File tree

3 files changed

+22
-41
lines changed

3 files changed

+22
-41
lines changed

vllm/model_executor/layers/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
2727
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
2828
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
29-
"TPUInt8LinearMethod", "GPTQLinearMethod"
29+
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod"
3030
]
3131

3232

vllm/model_executor/layers/quantization/fbgemm_fp8.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
from vllm.model_executor.layers.quantization.utils.quant_utils import (
1616
is_layer_skipped)
1717
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
18-
apply_fp8_linear, create_per_channel_scale_param)
19-
from vllm.model_executor.utils import set_weight_attrs
18+
apply_fp8_linear)
19+
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
20+
ModelWeightParameter)
2021
from vllm.platforms import current_platform
2122

2223
logger = init_logger(__name__)
@@ -85,6 +86,7 @@ def create_weights(
8586
params_dtype: torch.dtype,
8687
**extra_weight_attrs,
8788
):
89+
weight_loader = extra_weight_attrs.get("weight_loader")
8890
del input_size, output_size
8991
output_size_per_partition = sum(output_partition_sizes)
9092

@@ -95,20 +97,21 @@ def create_weights(
9597
layer.orig_dtype = params_dtype
9698

9799
# WEIGHT
98-
weight = Parameter(torch.empty(output_size_per_partition,
99-
input_size_per_partition,
100-
dtype=torch.float8_e4m3fn),
101-
requires_grad=False)
100+
weight = ModelWeightParameter(data=torch.empty(
101+
output_size_per_partition,
102+
input_size_per_partition,
103+
dtype=torch.float8_e4m3fn),
104+
input_dim=1,
105+
output_dim=0,
106+
weight_loader=weight_loader)
102107
layer.register_parameter("weight", weight)
103-
set_weight_attrs(weight, {
104-
"input_dim": 1,
105-
"output_dim": 0,
106-
**extra_weight_attrs,
107-
})
108108

109109
# WEIGHT SCALE
110-
weight_scale = create_per_channel_scale_param(output_partition_sizes,
111-
**extra_weight_attrs)
110+
weight_scale = ChannelQuantScaleParameter(data=torch.empty(
111+
(sum(output_partition_sizes), 1), dtype=torch.float32),
112+
output_dim=0,
113+
weight_loader=weight_loader)
114+
weight_scale[:] = torch.finfo(torch.float32).min
112115
layer.register_parameter("weight_scale", weight_scale)
113116

114117
# INPUT SCALE UPPER BOUND
@@ -118,6 +121,11 @@ def create_weights(
118121
layer.input_scale_ub = input_scale_ub
119122

120123
def process_weights_after_loading(self, layer: Module) -> None:
124+
# required by torch.compile
125+
layer.weight_scale = Parameter(layer.weight_scale.data,
126+
requires_grad=False)
127+
layer.weight = Parameter(layer.weight.data, requires_grad=False)
128+
121129
weight = layer.weight
122130
layer.weight = Parameter(weight.t(), requires_grad=False)
123131

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from typing import List, Optional, Tuple, Union
22

33
import torch
4-
from torch.nn import Parameter
54

65
from vllm import _custom_ops as ops
7-
from vllm.model_executor.utils import set_weight_attrs
86
from vllm.platforms import current_platform
97
from vllm.utils import is_hip
108

@@ -38,31 +36,6 @@ def all_close_1d(x: torch.Tensor) -> bool:
3836
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
3937

4038

41-
def create_per_tensor_scale_param(
42-
output_partition_sizes: List[int],
43-
**extra_weight_attrs,
44-
) -> Parameter:
45-
scale = Parameter(torch.empty(len(output_partition_sizes),
46-
dtype=torch.float32),
47-
requires_grad=False)
48-
scale[:] = torch.finfo(torch.float32).min
49-
set_weight_attrs(scale, {
50-
"needs_scalar_to_array": True,
51-
**extra_weight_attrs
52-
})
53-
return scale
54-
55-
56-
def create_per_channel_scale_param(output_partition_sizes: List[int],
57-
**extra_weight_attrs) -> Parameter:
58-
scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
59-
dtype=torch.float32),
60-
requires_grad=False)
61-
scale[:] = torch.finfo(torch.float32).min
62-
set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
63-
return scale
64-
65-
6639
def convert_to_channelwise(
6740
weight_scale: torch.Tensor,
6841
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:

0 commit comments

Comments (0)