     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    create_per_tensor_scale_param, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
+from vllm.model_executor.parameter import (ModelWeightParameter,
+                                           PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import is_hip, print_warning_once
@@ -137,6 +138,7 @@ def create_weights(
     ):
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         layer.logical_widths = output_partition_sizes
 
@@ -148,34 +150,41 @@ def create_weights(
         weight_dtype = (torch.float8_e4m3fn
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)
-        weight = Parameter(torch.empty(output_size_per_partition,
-                                       input_size_per_partition,
-                                       dtype=weight_dtype),
-                           requires_grad=False)
+
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=weight_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
 
         # If checkpoint is serialized fp8, load them.
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                  **extra_weight_attrs)
+            scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=torch.float32),
+                                            weight_loader=weight_loader)
+
+            scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
-                scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                      **extra_weight_attrs)
+                scale = PerTensorScaleParameter(data=torch.empty(
+                    len(output_partition_sizes), dtype=torch.float32),
+                                                weight_loader=weight_loader)
+
+                scale[:] = torch.finfo(torch.float32).min
                 layer.register_parameter("input_scale", scale)
             else:
                 layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
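
The per-tensor scale handling above can be illustrated outside vLLM. Below is a minimal standalone sketch in plain PyTorch (the shard sizes and scale values are hypothetical, and this is not vLLM's actual loader): a fused module keeps one fp32 scale slot per output shard, initialized to the float32 minimum so that any real checkpoint value overwrites the placeholder, and the fused weight is later requantized with the maximum scale across shards (see requantize_with_max_scale in the imports above).

import torch

# Hypothetical output shard sizes for a fused q/k/v projection.
output_partition_sizes = [1024, 1024, 512]

# One fp32 scale slot per shard, filled with a placeholder value
# (mirrors the `scale[:] = torch.finfo(torch.float32).min` line above).
scale = torch.empty(len(output_partition_sizes), dtype=torch.float32)
scale[:] = torch.finfo(torch.float32).min

# Hypothetical checkpoint scales loaded shard by shard; each one
# overwrites its placeholder slot.
for shard_id, loaded_scale in enumerate([0.02, 0.015, 0.03]):
    scale[shard_id] = loaded_scale

# After loading, the fused weight can be requantized with the max scale
# so a single per-tensor scale covers every shard.
max_scale = scale.max()
print(max_scale)  # tensor(0.0300)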
@@ -197,6 +206,11 @@ def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
+            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                    requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+                                                       requires_grad=False)
             # If using marlin (w8a16), kernel uses channelwise weights,
             # so extend the weight scales to be channelwise.
             if self.use_marlin:
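
For reference, a minimal standalone sketch of the "strip back to a plain Parameter" pattern added in process_weights_after_loading above. LoadTimeParameter is a hypothetical stand-in (not vLLM's ModelWeightParameter): the load-time Parameter subclass is rebuilt as a vanilla, frozen torch.nn.Parameter that shares the same underlying tensor, so downstream kernels never see the custom type.

import torch

class LoadTimeParameter(torch.nn.Parameter):
    """Hypothetical stand-in for a load-time Parameter subclass."""

layer = torch.nn.Module()
layer.weight = LoadTimeParameter(torch.randn(8, 8))

# Rebuild as a plain nn.Parameter over the same data, as done for
# layer.weight, layer.weight_scale and layer.input_scale in the diff.
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

assert type(layer.weight) is torch.nn.Parameter
assert layer.weight.requires_grad is False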