Skip to content

Commit 424a4c9

Browse files
committed
Update quantize.py
1 parent 608baca commit 424a4c9

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

quantize.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def cleanup_memory():
3333

3434
def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
3535
"""Quantize a tensor using per-tensor static scaling factor.
36-
3736
Args:
3837
tensor: The input tensor.
3938
"""
@@ -83,11 +82,12 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
8382

8483

8584
class FP8StaticLinearQuantizer(torch.nn.Module):
86-
def __init__(self, qweight, weight_scale):
85+
def __init__(self, qweight, weight_scale, bias):
8786
super().__init__()
8887
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
8988
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
9089
self.act_scale = None
90+
self.bias = bias
9191

9292
def forward(self, x):
9393
# Dynamically quantize
@@ -105,18 +105,19 @@ def forward(self, x):
105105
A_scale=self.act_scale,
106106
B=self.weight,
107107
B_scale=self.weight_scale,
108-
bias=None,
108+
bias=self.bias,
109109
out_dtype=x.dtype,
110110
)
111111
return output
112112

113113

114114
class FP8StaticLinear(torch.nn.Module):
115-
def __init__(self, qweight, weight_scale, act_scale=0.0):
115+
def __init__(self, qweight, weight_scale, bias, act_scale=0.0):
116116
super().__init__()
117117
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
118118
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
119119
self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)
120+
self.bias = bias
120121

121122
def per_tensor_quantize(
122123
self, tensor: torch.Tensor, inv_scale: float
@@ -135,17 +136,18 @@ def forward(self, x):
135136
A_scale=self.act_scale,
136137
B=self.weight,
137138
B_scale=self.weight_scale,
138-
bias=None,
139+
bias=self.bias,
139140
out_dtype=x.dtype,
140141
)
141142
return output
142143

143144

144145
class FP8DynamicLinear(torch.nn.Module):
145-
def __init__(self, qweight, scale):
146+
def __init__(self, qweight, scale, bias):
146147
super().__init__()
147148
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
148149
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
150+
self.bias = bias
149151

150152
def forward(self, x):
151153
qinput, x_scale = per_tensor_quantize(x)
@@ -154,7 +156,7 @@ def forward(self, x):
154156
A_scale=x_scale,
155157
B=self.weight,
156158
B_scale=self.weight_scale,
157-
bias=None,
159+
bias=self.bias,
158160
out_dtype=x.dtype,
159161
)
160162
return output
@@ -178,7 +180,7 @@ def quantize_weights(model):
178180
if not isinstance(linear, torch.nn.Linear):
179181
continue
180182
quant_weight, quant_scale = per_tensor_quantize(linear.weight)
181-
quant_linear = FP8DynamicLinear(quant_weight, quant_scale)
183+
quant_linear = FP8DynamicLinear(quant_weight, quant_scale, linear.bias)
182184
replace_module(model, name, quant_linear)
183185
del linear
184186
cleanup_memory()
@@ -191,7 +193,7 @@ def quantize_activations(model, calibration_tokens):
191193
if not isinstance(dynamic_quant_linear, FP8DynamicLinear):
192194
continue
193195
quantizer = FP8StaticLinearQuantizer(
194-
dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale
196+
dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale, dynamic_quant_linear.bias
195197
)
196198
replace_module(model, name, quantizer)
197199
del dynamic_quant_linear
@@ -210,14 +212,15 @@ def quantize_activations(model, calibration_tokens):
210212
if not isinstance(quantizer, FP8StaticLinearQuantizer):
211213
continue
212214
static_proj = FP8StaticLinear(
213-
quantizer.weight, quantizer.weight_scale, quantizer.act_scale
215+
quantizer.weight, quantizer.weight_scale, quantizer.bias, quantizer.act_scale
214216
)
215217
replace_module(model, name, static_proj)
216218
del quantizer
217219
cleanup_memory()
218220

219221

220222
def save_quantized_model(model, activation_scheme, save_dir):
223+
print(model)
221224
print(f"Saving the model to {save_dir}")
222225
static_q_dict = {
223226
"quantization_config": {

0 commit comments

Comments
 (0)