@@ -72,11 +72,19 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
         # Deal with empty tensors (triggered by empty MoE experts)
         return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)

+<<<<<<< HEAD
     # TODO: Disable native fp8 gemm for now, always just dequantize
     # native_fp8_support = (
     #     torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
     # )
     native_fp8_support = False
+=======
+    native_fp8_support = (
+        torch.cuda.is_available()
+        and torch.cuda.get_device_capability() >= (8, 9)
+        and False
+    )
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
     if native_fp8_support:
         need_reshape = A.dim() == 3
         if need_reshape:
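Both sides of this conflict leave `native_fp8_support` effectively False (HEAD hard-codes it, the incoming branch appends `and False`), so `fp8_gemm` always falls through to the dequantize path. For orientation, a minimal sketch of what that fallback amounts to with per-tensor scales is below; the name and exact arithmetic are illustrative, not the code further down in this file.

```python
import torch

def dequantize_gemm_sketch(A, A_scale, B, B_scale, bias, out_dtype):
    # Illustrative fallback only: upcast the FP8 operands with their
    # per-tensor scales, then run a regular matmul in out_dtype.
    A_deq = A.to(out_dtype) * A_scale      # activations, shape (..., K)
    B_deq = B.to(out_dtype) * B_scale      # weight, shape (N, K)
    out = A_deq @ B_deq.t()                # result, shape (..., N)
    if bias is not None:
        out = out + bias
    return out
```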
@@ -108,6 +116,7 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
 
 # Class responsible for quantizing weights
 class FP8DynamicLinear(torch.nn.Module):
+<<<<<<< HEAD
     def __init__(
         self,
         weight: torch.Tensor,
@@ -125,13 +134,114 @@ def forward(self, x):
             A=qinput,
             A_scale=x_scale,
             B=self.weight,
+=======
+    def __init__(
+        self,
+        qweight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        bias: torch.nn.Parameter,
+    ):
+        super().__init__()
+        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
+        self.bias = bias
+
+    def forward(self, x):
+        qinput, x_scale = per_tensor_quantize(x)
+        output = fp8_gemm(
+            A=qinput,
+            A_scale=x_scale,
+            B=self.qweight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
         )
         return output


+# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer
+class FP8StaticLinearQuantizer(torch.nn.Module):
+    def __init__(
+        self,
+        qweight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        bias: torch.nn.Parameter,
+        quantize_output: bool = False,
+    ):
+        super().__init__()
+        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
+        self.bias = bias
+        self.input_scale = None
+        self.output_scale = None
+        self.quantize_output = quantize_output
+
+    def forward(self, x):
+        qinput, x_input_scale = per_tensor_quantize(x)
+        if self.input_scale is None:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
+        elif x_input_scale > self.input_scale:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
+        output = fp8_gemm(
+            A=qinput,
+            A_scale=self.input_scale,
+            B=self.qweight,
+            B_scale=self.weight_scale,
+            bias=self.bias,
+            out_dtype=x.dtype,
+        )
+
+        # Optionally, quantize output and record scale
+        if self.quantize_output:
+            qoutput, output_scale = per_tensor_quantize(output)
+            if self.output_scale is None:
+                self.output_scale = torch.nn.Parameter(output_scale)
+            elif output_scale > self.output_scale:
+                self.output_scale = torch.nn.Parameter(output_scale)
+            output = qoutput.to(output.dtype) * output_scale
+
+        return output
+
+
+# Module responsible for representing the final checkpoint representation
+class FP8StaticLinear(torch.nn.Module):
+    def __init__(
+        self,
+        qweight: torch.nn.Parameter,
+        weight_scale: torch.nn.Parameter,
+        bias: torch.nn.Parameter,
+        input_scale: torch.nn.Parameter,
+        output_scale: Optional[torch.nn.Parameter] = None,
+    ):
+        super().__init__()
+        self.qweight = qweight
+        self.weight_scale = weight_scale
+        self.bias = bias
+        self.input_scale = input_scale
+        self.output_scale = output_scale
+
+    def forward(self, x):
+        qinput = static_per_tensor_quantize(x, self.input_scale)
+        output = fp8_gemm(
+            A=qinput,
+            A_scale=self.input_scale,
+            B=self.qweight,
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
+            B_scale=self.weight_scale,
+            bias=self.bias,
+            out_dtype=x.dtype,
+        )
+<<<<<<< HEAD
+=======
+
+        if self.output_scale:
+            qoutput = static_per_tensor_quantize(output, self.output_scale)
+            output = qoutput.to(output.dtype) * self.output_scale
+
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
+        return output
+
+
 # Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer
 class FP8StaticLinearQuantizer(torch.nn.Module):
     def __init__(
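The incoming modules all route through `per_tensor_quantize` and `static_per_tensor_quantize`, which are defined elsewhere in this file, and `FP8StaticLinearQuantizer` calibrates by keeping the running maximum of the observed per-batch scale. As a hedged sketch only (the real helpers may handle zeros, clamping, and dtype details differently), per-tensor FP8 quantization typically looks like this:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumed target format; needs a PyTorch build with float8 dtypes

def per_tensor_quantize_sketch(t: torch.Tensor):
    # Sketch: derive the scale from the abs-max so that x ~= q * scale.
    finfo = torch.finfo(FP8_DTYPE)
    amax = t.abs().max().clamp(min=1e-12)
    scale = (amax / finfo.max).float()
    q = (t / scale).clamp(finfo.min, finfo.max).to(FP8_DTYPE)
    return q, scale

def static_per_tensor_quantize_sketch(t: torch.Tensor, scale: torch.Tensor):
    # Sketch: quantize with a fixed, previously calibrated scale.
    finfo = torch.finfo(FP8_DTYPE)
    return (t / scale).clamp(finfo.min, finfo.max).to(FP8_DTYPE)
```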
@@ -237,7 +347,11 @@ def quantize_weights(
         quant_weight, weight_scale = per_tensor_quantize(linear.weight)
         bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
         quant_linear = FP8DynamicLinear(
+<<<<<<< HEAD
             weight=quant_weight, weight_scale=weight_scale, bias=bias
+=======
+            qweight=quant_weight, weight_scale=weight_scale, bias=bias
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
         )
         replace_module(model, name, quant_linear)
         del linear.weight
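The conflict here is only about the keyword: HEAD's `FP8DynamicLinear.__init__` takes `weight=`, while the incoming signature takes `qweight=`, so whichever constructor survives dictates this call site. The loop also depends on `replace_module`, defined elsewhere in the file; a typical implementation of that kind of helper (a sketch, not necessarily this repo's version) swaps a child module by its dotted name:

```python
import torch

def replace_module_sketch(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
    # Resolve "layers.0.mlp.down_proj" -> parent module + attribute name,
    # then overwrite the attribute with the quantized replacement.
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)
```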
@@ -259,7 +373,11 @@ def quantize_activations(
         ):
             continue
         quantizer = FP8StaticLinearQuantizer(
+<<<<<<< HEAD
             weight=dynamic_quant_linear.weight,
+=======
+            qweight=dynamic_quant_linear.qweight,
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
             weight_scale=dynamic_quant_linear.weight_scale,
             bias=dynamic_quant_linear.bias,
             quantize_output=(
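The incoming branch threads a `quantize_output=` flag into the observer; per the commit message, that is what lets some projection outputs be calibrated as kv cache scales. The actual condition sits outside this hunk; a purely hypothetical selection helper would look like:

```python
# Hypothetical: enable output-scale calibration only for attention k/v projections.
# The real predicate used by this PR is not shown in this hunk.
KV_PROJ_SUFFIXES = ("k_proj", "v_proj")

def wants_output_scale(module_name: str) -> bool:
    return module_name.endswith(KV_PROJ_SUFFIXES)
```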
@@ -272,12 +390,22 @@ def quantize_activations(
     cleanup_memory()

     # Pass through calibration data to measure activation scales
+<<<<<<< HEAD
     with torch.inference_mode():
         with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
             for row_idx in range(calibration_tokens.shape[0]):
                 model(calibration_tokens[row_idx].reshape(1, -1))
                 cleanup_memory()
                 pbar.update(1)
+=======
+    with tqdm.tqdm(
+        total=calibration_tokens.shape[0], desc="Calibrating activation scales"
+    ) as pbar:
+        for row_idx in range(calibration_tokens.shape[0]):
+            model(calibration_tokens[row_idx].reshape(1, -1))
+            cleanup_memory()
+            pbar.update(1)
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)

     # Replace dynamic quantizer observer with StaticLinear for export
     for name, quantizer in model.named_modules():
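Both sides of this conflict drive calibration the same way, feeding one row of `calibration_tokens` at a time; HEAD additionally wraps the pass in `torch.inference_mode()`, while the incoming branch only reflows the `tqdm` call. `calibration_tokens` is expected to be a `(num_samples, seq_len)` tensor of token ids; a hedged example of producing one (the model and prompts are placeholders, not taken from this PR):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder model
texts = ["The quick brown fox jumps over the lazy dog."] * 8    # placeholder prompts
calibration_tokens = tokenizer(
    texts, return_tensors="pt", padding="max_length", truncation=True, max_length=64
).input_ids  # shape (8, 64); each row is passed as model(row.reshape(1, -1))
```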
@@ -287,7 +415,11 @@ def quantize_activations(
         ):
             continue
         static_proj = FP8StaticLinear(
+<<<<<<< HEAD
             weight=quantizer.weight,
+=======
+            qweight=quantizer.qweight,
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
             weight_scale=quantizer.weight_scale,
             bias=quantizer.bias,
             input_scale=quantizer.input_scale,
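After calibration, each observer is swapped for an `FP8StaticLinear` that carries the frozen `input_scale` (and, where `quantize_output` was enabled, an `output_scale`, i.e. the kv cache scale this commit refers to). A small hedged helper for sanity-checking the result, assuming only the attribute names introduced in this diff:

```python
def report_scales(model):
    # List every module that picked up calibrated scales during the pass above.
    for name, module in model.named_modules():
        input_scale = getattr(module, "input_scale", None)
        if input_scale is None:
            continue
        output_scale = getattr(module, "output_scale", None)
        line = f"{name}: input_scale={input_scale.item():.6f}"
        if output_scale is not None:
            line += f", output_scale={output_scale.item():.6f}"
        print(line)
```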