@@ -110,12 +110,12 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
 class FP8DynamicLinear(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.Tensor,
+        weight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.nn.Parameter,
     ):
         super().__init__()
-        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
         self.bias = bias

@@ -124,7 +124,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=x_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
@@ -136,13 +136,13 @@ def forward(self, x):
 class FP8StaticLinearQuantizer(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.Tensor,
+        weight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.nn.Parameter,
         quantize_output: bool = False,
     ):
         super().__init__()
-        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
         self.bias = bias
         self.input_scale = None
@@ -158,7 +158,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
@@ -180,14 +180,14 @@ def forward(self, x):
 class FP8StaticLinear(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.nn.Parameter,
+        weight: torch.nn.Parameter,
         weight_scale: torch.nn.Parameter,
         bias: torch.nn.Parameter,
         input_scale: torch.nn.Parameter,
         output_scale: Optional[torch.nn.Parameter] = None,
     ):
         super().__init__()
-        self.qweight = qweight
+        self.weight = weight
         self.weight_scale = weight_scale
         self.bias = bias
         self.input_scale = input_scale
@@ -198,7 +198,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
@@ -237,7 +237,7 @@ def quantize_weights(
         quant_weight, weight_scale = per_tensor_quantize(linear.weight)
         bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
         quant_linear = FP8DynamicLinear(
-            qweight=quant_weight, weight_scale=weight_scale, bias=bias
+            weight=quant_weight, weight_scale=weight_scale, bias=bias
         )
         replace_module(model, name, quant_linear)
         del linear.weight
@@ -259,7 +259,7 @@ def quantize_activations(
         ):
             continue
         quantizer = FP8StaticLinearQuantizer(
-            qweight=dynamic_quant_linear.qweight,
+            weight=dynamic_quant_linear.weight,
             weight_scale=dynamic_quant_linear.weight_scale,
             bias=dynamic_quant_linear.bias,
             quantize_output=(
@@ -288,7 +288,7 @@ def quantize_activations(
         ):
             continue
         static_proj = FP8StaticLinear(
-            qweight=quantizer.qweight,
+            weight=quantizer.weight,
             weight_scale=quantizer.weight_scale,
             bias=quantizer.bias,
             input_scale=quantizer.input_scale,
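For context, a minimal usage sketch of the renamed interface, assuming the FP8DynamicLinear class shown in the diff above. The per_tensor_quantize helper below is a simplified stand-in for the one defined earlier in the file, and the layer size, scale math, and variable names are illustrative assumptions, not part of this change.

import torch

def per_tensor_quantize(tensor: torch.Tensor):
    # Simplified stand-in (assumed behavior): scale into the FP8-E4M3 range and
    # return the quantized weight together with its dequantization scale.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = finfo.max / tensor.abs().max().clamp(min=1e-12)
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale.float().reciprocal()

linear = torch.nn.Linear(16, 16, bias=False)
quant_weight, weight_scale = per_tensor_quantize(linear.weight.detach())
# After this change the parameter is exposed as `weight` rather than `qweight`:
fp8_linear = FP8DynamicLinear(weight=quant_weight, weight_scale=weight_scale, bias=None)
assert fp8_linear.weight.dtype == torch.float8_e4m3fn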