@@ -143,12 +143,12 @@ def forward(self, x):
 =======
     def __init__(
         self,
-        qweight: torch.Tensor,
+        weight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.nn.Parameter,
     ):
         super().__init__()
-        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
         self.bias = bias
@@ -157,7 +157,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=x_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
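Both fp8_gemm call sites pair an FP8 tensor with a per-tensor scale produced by per_tensor_quantize, which this diff references later but never shows. A minimal sketch of such a helper, assuming the common convention scale = amax / fp8_max so that dequantization is q * scale; this is an illustration, not the repository's actual implementation:

import torch

def per_tensor_quantize(tensor: torch.Tensor):
    # One scale for the whole tensor: map the observed absolute maximum
    # onto the largest value representable in FP8 (e4m3).
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = tensor.abs().max().clamp(min=1e-12)
    scale = amax / finfo.max
    qtensor = (tensor / scale).clamp(min=finfo.min, max=finfo.max)
    return qtensor.to(torch.float8_e4m3fn), scale.float()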
@@ -169,13 +169,13 @@ def forward(self, x):
 class FP8StaticLinearQuantizer(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.Tensor,
+        weight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.nn.Parameter,
         quantize_output: bool = False,
     ):
         super().__init__()
-        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
         self.bias = bias
         self.input_scale = None
@@ -191,7 +191,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
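FP8StaticLinearQuantizer initializes self.input_scale = None and must populate it during calibration forward passes before the fp8_gemm call above can use it. That observation step is not visible in this diff; a hedged sketch of one common approach, keeping the largest per-tensor scale seen across calibration batches:

def forward(self, x):
    qinput, x_scale = per_tensor_quantize(x)
    # Remember the widest activation range observed so far.
    if self.input_scale is None or x_scale > self.input_scale:
        self.input_scale = torch.nn.Parameter(x_scale, requires_grad=False)
    # ... then the fp8_gemm call shown in the hunk above.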
@@ -213,14 +213,14 @@ def forward(self, x):
 class FP8StaticLinear(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.nn.Parameter,
+        weight: torch.nn.Parameter,
         weight_scale: torch.nn.Parameter,
         bias: torch.nn.Parameter,
         input_scale: torch.nn.Parameter,
         output_scale: Optional[torch.nn.Parameter] = None,
     ):
         super().__init__()
-        self.qweight = qweight
+        self.weight = weight
         self.weight_scale = weight_scale
         self.bias = bias
         self.input_scale = input_scale
@@ -231,6 +231,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
+<<<<<<< HEAD
             B=self.qweight,
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
             B_scale=self.weight_scale,
@@ -314,6 +315,8 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
+=======
+>>>>>>> def2049 (Fix weight name)
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
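fp8_gemm itself is outside this diff. As a reference for the keyword interface the call sites rely on (FP8 A and B with per-tensor scales, B laid out like torch.nn.Linear.weight), here is a hedged dequantize-and-matmul fallback; a production path would instead dispatch to a fused FP8 kernel such as torch._scaled_mm:

def fp8_gemm_reference(A, A_scale, B, B_scale, bias, out_dtype):
    # Dequantize both operands, then run a plain matmul.
    # B is (out_features, in_features), matching torch.nn.Linear.
    output = (A.to(out_dtype) * A_scale) @ (B.to(out_dtype) * B_scale).t()
    if bias is not None:
        output = output + bias
    return output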
@@ -353,11 +356,15 @@ def quantize_weights(
         quant_weight, weight_scale = per_tensor_quantize(linear.weight)
         bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
         quant_linear = FP8DynamicLinear(
+<<<<<<< HEAD
 <<<<<<< HEAD
             weight=quant_weight, weight_scale=weight_scale, bias=bias
 =======
             qweight=quant_weight, weight_scale=weight_scale, bias=bias
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+            weight=quant_weight, weight_scale=weight_scale, bias=bias
+>>>>>>> def2049 (Fix weight name)
         )
         replace_module(model, name, quant_linear)
         del linear.weight
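Resolved in favor of this commit, the conflicted call above collapses to FP8DynamicLinear(weight=quant_weight, weight_scale=weight_scale, bias=bias). The replace_module helper that follows it is likewise not shown; a minimal sketch of the usual pattern, assuming dotted names as produced by model.named_modules():

def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
    # Split a dotted path like "layers.0.mlp.down_proj" into parent path
    # and attribute name, then swap the child on its parent in place.
    if "." in name:
        parent_name, child_name = name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)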
@@ -379,11 +386,15 @@ def quantize_activations(
         ):
             continue
         quantizer = FP8StaticLinearQuantizer(
+<<<<<<< HEAD
 <<<<<<< HEAD
             weight=dynamic_quant_linear.weight,
 =======
             qweight=dynamic_quant_linear.qweight,
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+            weight=dynamic_quant_linear.weight,
+>>>>>>> def2049 (Fix weight name)
             weight_scale=dynamic_quant_linear.weight_scale,
             bias=dynamic_quant_linear.bias,
             quantize_output=(
@@ -421,11 +432,15 @@ def quantize_activations(
         ):
             continue
         static_proj = FP8StaticLinear(
+<<<<<<< HEAD
 <<<<<<< HEAD
             weight=quantizer.weight,
 =======
             qweight=quantizer.qweight,
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+            weight=quantizer.weight,
+>>>>>>> def2049 (Fix weight name)
             weight_scale=quantizer.weight_scale,
             bias=quantizer.bias,
             input_scale=quantizer.input_scale,
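Taken together, the diff's three modules form a two-stage pipeline: quantize_weights swaps each torch.nn.Linear for an FP8DynamicLinear, then quantize_activations wraps those in FP8StaticLinearQuantizer to calibrate input scales and finally freezes each into an FP8StaticLinear. The exact signatures are not visible here, so the driver below is a hedged sketch with assumed arguments:

quantize_weights(model)                     # Linear -> FP8DynamicLinear
quantize_activations(model, calib_batches)  # calibrate scales, freeze to FP8StaticLinear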