15
15
from vllm .model_executor .layers .quantization .utils .quant_utils import (
16
16
is_layer_skipped )
17
17
from vllm .model_executor .layers .quantization .utils .w8a8_utils import (
18
- apply_fp8_linear , create_per_channel_scale_param )
19
- from vllm .model_executor .utils import set_weight_attrs
18
+ apply_fp8_linear )
19
+ from vllm .model_executor .parameter import (ChannelQuantScaleParameter ,
20
+ ModelWeightParameter )
20
21
from vllm .platforms import current_platform
21
22
22
23
logger = init_logger (__name__ )
@@ -85,6 +86,7 @@ def create_weights(
85
86
params_dtype : torch .dtype ,
86
87
** extra_weight_attrs ,
87
88
):
89
+ weight_loader = extra_weight_attrs .get ("weight_loader" )
88
90
del input_size , output_size
89
91
output_size_per_partition = sum (output_partition_sizes )
90
92
@@ -95,20 +97,21 @@ def create_weights(
95
97
layer .orig_dtype = params_dtype
96
98
97
99
# WEIGHT
98
- weight = Parameter (torch .empty (output_size_per_partition ,
99
- input_size_per_partition ,
100
- dtype = torch .float8_e4m3fn ),
101
- requires_grad = False )
100
+ weight = ModelWeightParameter (data = torch .empty (
101
+ output_size_per_partition ,
102
+ input_size_per_partition ,
103
+ dtype = torch .float8_e4m3fn ),
104
+ input_dim = 1 ,
105
+ output_dim = 0 ,
106
+ weight_loader = weight_loader )
102
107
layer .register_parameter ("weight" , weight )
103
- set_weight_attrs (weight , {
104
- "input_dim" : 1 ,
105
- "output_dim" : 0 ,
106
- ** extra_weight_attrs ,
107
- })
108
108
109
109
# WEIGHT SCALE
110
- weight_scale = create_per_channel_scale_param (output_partition_sizes ,
111
- ** extra_weight_attrs )
110
+ weight_scale = ChannelQuantScaleParameter (data = torch .empty (
111
+ (sum (output_partition_sizes ), 1 ), dtype = torch .float32 ),
112
+ output_dim = 0 ,
113
+ weight_loader = weight_loader )
114
+ weight_scale [:] = torch .finfo (torch .float32 ).min
112
115
layer .register_parameter ("weight_scale" , weight_scale )
113
116
114
117
# INPUT SCALE UPPER BOUND
@@ -118,6 +121,11 @@ def create_weights(
118
121
layer .input_scale_ub = input_scale_ub
119
122
120
123
def process_weights_after_loading (self , layer : Module ) -> None :
124
+ # required by torch.compile
125
+ layer .weight_scale = Parameter (layer .weight_scale .data ,
126
+ requires_grad = False )
127
+ layer .weight = Parameter (layer .weight .data , requires_grad = False )
128
+
121
129
weight = layer .weight
122
130
layer .weight = Parameter (weight .t (), requires_grad = False )
123
131
0 commit comments