
Commit f934b0e

Fix weight name

1 parent 6323dff

1 file changed: +23 −8

auto_fp8/quantize.py

Lines changed: 23 additions & 8 deletions
@@ -143,12 +143,12 @@ def forward(self, x):
 =======
     def __init__(
         self,
-        qweight: torch.Tensor,
+        weight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.nn.Parameter,
     ):
         super().__init__()
-        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
         self.bias = bias
 
@@ -157,7 +157,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=x_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
@@ -169,13 +169,13 @@ def forward(self, x):
 class FP8StaticLinearQuantizer(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.Tensor,
+        weight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.nn.Parameter,
         quantize_output: bool = False,
     ):
         super().__init__()
-        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
         self.bias = bias
         self.input_scale = None
@@ -191,7 +191,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
@@ -213,14 +213,14 @@ def forward(self, x):
 class FP8StaticLinear(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.nn.Parameter,
+        weight: torch.nn.Parameter,
         weight_scale: torch.nn.Parameter,
         bias: torch.nn.Parameter,
         input_scale: torch.nn.Parameter,
         output_scale: Optional[torch.nn.Parameter] = None,
     ):
         super().__init__()
-        self.qweight = qweight
+        self.weight = weight
         self.weight_scale = weight_scale
         self.bias = bias
         self.input_scale = input_scale
@@ -231,6 +231,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
+<<<<<<< HEAD
             B=self.qweight,
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
             B_scale=self.weight_scale,
@@ -314,6 +315,8 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
+=======
+>>>>>>> def2049 (Fix weight name)
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
@@ -353,11 +356,15 @@ def quantize_weights(
         quant_weight, weight_scale = per_tensor_quantize(linear.weight)
         bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
         quant_linear = FP8DynamicLinear(
+<<<<<<< HEAD
 <<<<<<< HEAD
             weight=quant_weight, weight_scale=weight_scale, bias=bias
 =======
             qweight=quant_weight, weight_scale=weight_scale, bias=bias
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+            weight=quant_weight, weight_scale=weight_scale, bias=bias
+>>>>>>> def2049 (Fix weight name)
         )
         replace_module(model, name, quant_linear)
         del linear.weight
@@ -379,11 +386,15 @@ def quantize_activations(
         ):
             continue
         quantizer = FP8StaticLinearQuantizer(
+<<<<<<< HEAD
 <<<<<<< HEAD
             weight=dynamic_quant_linear.weight,
 =======
             qweight=dynamic_quant_linear.qweight,
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+            weight=dynamic_quant_linear.weight,
+>>>>>>> def2049 (Fix weight name)
             weight_scale=dynamic_quant_linear.weight_scale,
             bias=dynamic_quant_linear.bias,
             quantize_output=(
@@ -421,11 +432,15 @@ def quantize_activations(
         ):
             continue
         static_proj = FP8StaticLinear(
+<<<<<<< HEAD
 <<<<<<< HEAD
             weight=quantizer.weight,
 =======
             qweight=quantizer.qweight,
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+            weight=quantizer.weight,
+>>>>>>> def2049 (Fix weight name)
             weight_scale=quantizer.weight_scale,
             bias=quantizer.bias,
             input_scale=quantizer.input_scale,
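
For context, below is a minimal sketch of FP8DynamicLinear after this rename. The per_tensor_quantize and fp8_gemm helpers here are simplified stand-ins (per-tensor FP8 scaling and a dequantize-then-matmul fallback), not the implementations in auto_fp8/quantize.py; only the weight / weight_scale / bias attribute and argument names follow this commit.

import torch

def per_tensor_quantize(tensor: torch.Tensor):
    # Simplified stand-in: one FP8 (e4m3) scale for the whole tensor.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = tensor.abs().max().clamp(min=1e-12) / finfo.max
    qtensor = (tensor / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qtensor, scale.to(torch.float32)

def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
    # Simplified stand-in: dequantize both operands and run a regular matmul.
    a = A.to(out_dtype) * A_scale.to(out_dtype)
    b = B.to(out_dtype) * B_scale.to(out_dtype)
    return torch.nn.functional.linear(a, b, bias)

class FP8DynamicLinear(torch.nn.Module):
    def __init__(self, weight, weight_scale, bias):
        super().__init__()
        # Attribute names per this commit: `weight`, not `qweight`.
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        self.bias = bias

    def forward(self, x):
        # Dynamic activation quantization: the input scale is computed per call.
        qinput, x_scale = per_tensor_quantize(x)
        return fp8_gemm(
            A=qinput,
            A_scale=x_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=self.bias,
            out_dtype=x.dtype,
        )

Construction mirrors the diff: quant_weight, weight_scale = per_tensor_quantize(linear.weight), then FP8DynamicLinear(weight=quant_weight, weight_scale=weight_scale, bias=linear.bias).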
