Commit def2049

Fix weight name
1 parent 959bdbc commit def2049

File tree

1 file changed: +12 −12 lines changed

auto_fp8/quantize.py

Lines changed: 12 additions & 12 deletions
@@ -110,12 +110,12 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
 class FP8DynamicLinear(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.Tensor,
+        weight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.nn.Parameter,
     ):
         super().__init__()
-        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
         self.bias = bias

@@ -124,7 +124,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=x_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,

@@ -136,13 +136,13 @@ def forward(self, x):
 class FP8StaticLinearQuantizer(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.Tensor,
+        weight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.nn.Parameter,
         quantize_output: bool = False,
     ):
         super().__init__()
-        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
         self.bias = bias
         self.input_scale = None

@@ -158,7 +158,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,

@@ -180,14 +180,14 @@ def forward(self, x):
 class FP8StaticLinear(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.nn.Parameter,
+        weight: torch.nn.Parameter,
         weight_scale: torch.nn.Parameter,
         bias: torch.nn.Parameter,
         input_scale: torch.nn.Parameter,
         output_scale: Optional[torch.nn.Parameter] = None,
     ):
         super().__init__()
-        self.qweight = qweight
+        self.weight = weight
         self.weight_scale = weight_scale
         self.bias = bias
         self.input_scale = input_scale

@@ -198,7 +198,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,

@@ -237,7 +237,7 @@ def quantize_weights(
         quant_weight, weight_scale = per_tensor_quantize(linear.weight)
         bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
         quant_linear = FP8DynamicLinear(
-            qweight=quant_weight, weight_scale=weight_scale, bias=bias
+            weight=quant_weight, weight_scale=weight_scale, bias=bias
         )
         replace_module(model, name, quant_linear)
         del linear.weight

@@ -259,7 +259,7 @@ def quantize_activations(
         ):
             continue
         quantizer = FP8StaticLinearQuantizer(
-            qweight=dynamic_quant_linear.qweight,
+            weight=dynamic_quant_linear.weight,
             weight_scale=dynamic_quant_linear.weight_scale,
             bias=dynamic_quant_linear.bias,
             quantize_output=(

@@ -288,7 +288,7 @@ def quantize_activations(
         ):
             continue
         static_proj = FP8StaticLinear(
-            qweight=quantizer.qweight,
+            weight=quantizer.weight,
             weight_scale=quantizer.weight_scale,
             bias=quantizer.bias,
             input_scale=quantizer.input_scale,
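
For context (not part of the diff): with the tensor registered as weight instead of qweight, each quantized module serializes its parameters under the same key layout a plain torch.nn.Linear produces ("weight" and "bias"), with "weight_scale" alongside. Below is a minimal sketch of the renamed constructor keyword, assuming FP8DynamicLinear and per_tensor_quantize are importable from auto_fp8.quantize as defined in the file above; the shapes are illustrative only.

# Sketch only: assumes FP8DynamicLinear and per_tensor_quantize are importable
# from auto_fp8.quantize as shown in the diff above; shapes are arbitrary.
import torch
from auto_fp8.quantize import FP8DynamicLinear, per_tensor_quantize

linear = torch.nn.Linear(16, 32, bias=True)

# per_tensor_quantize returns the FP8 weight tensor and its per-tensor scale,
# exactly as quantize_weights() uses it above.
quant_weight, weight_scale = per_tensor_quantize(linear.weight)

quant_linear = FP8DynamicLinear(
    weight=quant_weight,        # this keyword was qweight= before the commit
    weight_scale=weight_scale,
    bias=linear.bias,
)

# The quantized tensor now appears under the conventional "weight" key,
# matching the key layout of a standard nn.Linear checkpoint.
print(sorted(quant_linear.state_dict().keys()))
# ['bias', 'weight', 'weight_scale']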
