@@ -33,7 +33,6 @@ def cleanup_memory():
33
33
34
34
def per_tensor_quantize (tensor : torch .Tensor ) -> Tuple [torch .Tensor , float ]:
35
35
"""Quantize a tensor using per-tensor static scaling factor.
36
-
37
36
Args:
38
37
tensor: The input tensor.
39
38
"""
@@ -83,11 +82,12 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
83
82
84
83
85
84
class FP8StaticLinearQuantizer (torch .nn .Module ):
86
- def __init__ (self , qweight , weight_scale ):
85
+ def __init__ (self , qweight , weight_scale , bias ):
87
86
super ().__init__ ()
88
87
self .weight = torch .nn .Parameter (qweight , requires_grad = False )
89
88
self .weight_scale = torch .nn .Parameter (weight_scale , requires_grad = False )
90
89
self .act_scale = None
90
+ self .bias = bias
91
91
92
92
def forward (self , x ):
93
93
# Dynamically quantize
@@ -105,18 +105,19 @@ def forward(self, x):
105
105
A_scale = self .act_scale ,
106
106
B = self .weight ,
107
107
B_scale = self .weight_scale ,
108
- bias = None ,
108
+ bias = self . bias ,
109
109
out_dtype = x .dtype ,
110
110
)
111
111
return output
112
112
113
113
114
114
class FP8StaticLinear (torch .nn .Module ):
115
- def __init__ (self , qweight , weight_scale , act_scale = 0.0 ):
115
+ def __init__ (self , qweight , weight_scale , bias , act_scale = 0.0 ):
116
116
super ().__init__ ()
117
117
self .weight = torch .nn .Parameter (qweight , requires_grad = False )
118
118
self .weight_scale = torch .nn .Parameter (weight_scale , requires_grad = False )
119
119
self .act_scale = torch .nn .Parameter (act_scale , requires_grad = False )
120
+ self .bias = bias
120
121
121
122
def per_tensor_quantize (
122
123
self , tensor : torch .Tensor , inv_scale : float
@@ -135,17 +136,18 @@ def forward(self, x):
135
136
A_scale = self .act_scale ,
136
137
B = self .weight ,
137
138
B_scale = self .weight_scale ,
138
- bias = None ,
139
+ bias = self . bias ,
139
140
out_dtype = x .dtype ,
140
141
)
141
142
return output
142
143
143
144
144
145
class FP8DynamicLinear (torch .nn .Module ):
145
- def __init__ (self , qweight , scale ):
146
+ def __init__ (self , qweight , scale , bias ):
146
147
super ().__init__ ()
147
148
self .weight = torch .nn .Parameter (qweight , requires_grad = False )
148
149
self .weight_scale = torch .nn .Parameter (scale , requires_grad = False )
150
+ self .bias = bias
149
151
150
152
def forward (self , x ):
151
153
qinput , x_scale = per_tensor_quantize (x )
@@ -154,7 +156,7 @@ def forward(self, x):
154
156
A_scale = x_scale ,
155
157
B = self .weight ,
156
158
B_scale = self .weight_scale ,
157
- bias = None ,
159
+ bias = self . bias ,
158
160
out_dtype = x .dtype ,
159
161
)
160
162
return output
@@ -178,7 +180,7 @@ def quantize_weights(model):
178
180
if not isinstance (linear , torch .nn .Linear ):
179
181
continue
180
182
quant_weight , quant_scale = per_tensor_quantize (linear .weight )
181
- quant_linear = FP8DynamicLinear (quant_weight , quant_scale )
183
+ quant_linear = FP8DynamicLinear (quant_weight , quant_scale , linear . bias )
182
184
replace_module (model , name , quant_linear )
183
185
del linear
184
186
cleanup_memory ()
@@ -191,7 +193,7 @@ def quantize_activations(model, calibration_tokens):
191
193
if not isinstance (dynamic_quant_linear , FP8DynamicLinear ):
192
194
continue
193
195
quantizer = FP8StaticLinearQuantizer (
194
- dynamic_quant_linear .weight , dynamic_quant_linear .weight_scale
196
+ dynamic_quant_linear .weight , dynamic_quant_linear .weight_scale , dynamic_quant_linear . bias
195
197
)
196
198
replace_module (model , name , quantizer )
197
199
del dynamic_quant_linear
@@ -210,14 +212,15 @@ def quantize_activations(model, calibration_tokens):
210
212
if not isinstance (quantizer , FP8StaticLinearQuantizer ):
211
213
continue
212
214
static_proj = FP8StaticLinear (
213
- quantizer .weight , quantizer .weight_scale , quantizer .act_scale
215
+ quantizer .weight , quantizer .weight_scale , quantizer .bias , quantizer . act_scale
214
216
)
215
217
replace_module (model , name , static_proj )
216
218
del quantizer
217
219
cleanup_memory ()
218
220
219
221
220
222
def save_quantized_model (model , activation_scheme , save_dir ):
223
+ print (model )
221
224
print (f"Saving the model to { save_dir } " )
222
225
static_q_dict = {
223
226
"quantization_config" : {
0 commit comments