@@ -1,6 +1,6 @@
 import gc
 import re
-from typing import List, Tuple
+from typing import Optional, Tuple
 import copy
 
 import torch
@@ -61,13 +61,21 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
     return qweight, scale
 
 
+def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
+    return qweight.to(torch.float8_e4m3fn)
+
+
 def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
     if A.numel() == 0:
         # Deal with empty tensors (triggered by empty MoE experts)
         return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
-
+
     native_fp8_support = (
-        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+        torch.cuda.is_available()
+        and torch.cuda.get_device_capability() >= (8, 9)
+        and False
     )
     if native_fp8_support:
         need_reshape = A.dim() == 3
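Reviewer note: static_per_tensor_quantize reuses a scale supplied by the caller (e.g. one recorded during calibration), whereas the existing per_tensor_quantize derives the scale from the tensor itself. A minimal sketch of the relationship, assuming both helpers are in scope and the installed PyTorch supports float8_e4m3fn; the tensor shape is illustrative:

import torch

x = torch.randn(4, 8, dtype=torch.float16)

# Dynamic: the scale is derived from the tensor's own max value.
qx_dynamic, x_scale = per_tensor_quantize(x)

# Static: the caller supplies a previously observed scale.
qx_static = static_per_tensor_quantize(x, x_scale)

# Both return float8_e4m3fn tensors; given the same scale they should agree.
assert qx_dynamic.dtype == qx_static.dtype == torch.float8_e4m3fn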
@@ -98,84 +106,108 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
     return output
 
 
-class FP8StaticLinearQuantizer(torch.nn.Module):
+# Class responsible for quantizing weights
+class FP8DynamicLinear(torch.nn.Module):
     def __init__(
-        self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor
+        self,
+        qweight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        bias: torch.nn.Parameter,
     ):
         super().__init__()
-        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.input_scale = None
         self.bias = bias
 
     def forward(self, x):
-        qinput, x_input_scale = per_tensor_quantize(x)
-        if self.input_scale is None:
-            self.input_scale = torch.nn.Parameter(x_input_scale)
-        elif x_input_scale > self.input_scale:
-            self.input_scale = torch.nn.Parameter(x_input_scale)
+        qinput, x_scale = per_tensor_quantize(x)
         output = fp8_gemm(
             A=qinput,
-            A_scale=self.input_scale,
-            B=self.weight,
+            A_scale=x_scale,
+            B=self.qweight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
         )
         return output
 
 
-class FP8StaticLinear(torch.nn.Module):
+# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer
+class FP8StaticLinearQuantizer(torch.nn.Module):
     def __init__(
         self,
         qweight: torch.Tensor,
         weight_scale: torch.Tensor,
-        bias: torch.Tensor,
-        input_scale: float = 1.0,
+        bias: torch.nn.Parameter,
+        quantize_output: bool = False,
     ):
         super().__init__()
-        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
         self.bias = bias
-
-    def per_tensor_quantize(
-        self, tensor: torch.Tensor, inv_scale: float
-    ) -> torch.Tensor:
-        finfo = torch.finfo(torch.float8_e4m3fn)
-        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
-        return qweight.to(torch.float8_e4m3fn)
+        self.input_scale = None
+        self.output_scale = None
+        self.quantize_output = quantize_output
 
     def forward(self, x):
-        qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
+        qinput, x_input_scale = per_tensor_quantize(x)
+        if self.input_scale is None:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
+        elif x_input_scale > self.input_scale:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
-            B=self.weight,
+            B=self.qweight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
         )
+
+        # Optionally, quantize output and record scale
+        if self.quantize_output:
+            qoutput, output_scale = per_tensor_quantize(output)
+            if self.output_scale is None:
+                self.output_scale = torch.nn.Parameter(output_scale)
+            elif output_scale > self.output_scale:
+                self.output_scale = torch.nn.Parameter(output_scale)
+            output = qoutput.to(output.dtype) * output_scale
+
         return output
 
 
-class FP8DynamicLinear(torch.nn.Module):
-    def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor):
+# Module responsible for representing the final checkpoint representation
+class FP8StaticLinear(torch.nn.Module):
+    def __init__(
+        self,
+        qweight: torch.nn.Parameter,
+        weight_scale: torch.nn.Parameter,
+        bias: torch.nn.Parameter,
+        input_scale: torch.nn.Parameter,
+        output_scale: Optional[torch.nn.Parameter] = None,
+    ):
         super().__init__()
-        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
-        self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
+        self.qweight = qweight
+        self.weight_scale = weight_scale
         self.bias = bias
+        self.input_scale = input_scale
+        self.output_scale = output_scale
 
     def forward(self, x):
-        qinput, x_scale = per_tensor_quantize(x)
+        qinput = static_per_tensor_quantize(x, self.input_scale)
         output = fp8_gemm(
             A=qinput,
-            A_scale=x_scale,
-            B=self.weight,
+            A_scale=self.input_scale,
+            B=self.qweight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
         )
+
+        if self.output_scale:
+            qoutput = static_per_tensor_quantize(output, self.output_scale)
+            output = qoutput.to(output.dtype) * self.output_scale
+
         return output
 
 
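Reviewer note: after the rename and reorder, the three classes form a pipeline: FP8DynamicLinear holds FP8 weights and quantizes activations on the fly, FP8StaticLinearQuantizer is the calibration-time observer that records input (and optionally output) scales, and FP8StaticLinear is the frozen checkpoint module. A hypothetical single-layer walk-through, assuming the classes and per_tensor_quantize above are in scope; the real flow is driven by quantize_weights()/quantize_activations() further down:

import copy
import torch

linear = torch.nn.Linear(16, 32, bias=True, dtype=torch.float16)

# 1. Quantize the weights once; activations are still quantized per call.
qweight, weight_scale = per_tensor_quantize(linear.weight)
dynamic = FP8DynamicLinear(
    qweight=qweight, weight_scale=weight_scale, bias=copy.deepcopy(linear.bias)
)

# 2. Swap in the observer and run calibration data through it so it records
#    the largest input scale seen (and output scale if quantize_output=True).
observer = FP8StaticLinearQuantizer(
    qweight=dynamic.qweight,
    weight_scale=dynamic.weight_scale,
    bias=dynamic.bias,
    quantize_output=False,
)
for _ in range(8):
    observer(torch.randn(1, 16, dtype=torch.float16))

# 3. Freeze the observed scales into the final checkpoint module.
static = FP8StaticLinear(
    qweight=observer.qweight,
    weight_scale=observer.weight_scale,
    bias=observer.bias,
    input_scale=observer.input_scale,
    output_scale=observer.output_scale,
)
y = static(torch.randn(1, 16, dtype=torch.float16))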
@@ -194,7 +226,6 @@ def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.
 def quantize_weights(
     model: AutoModelForCausalLM,
     quantize_config: BaseQuantizeConfig,
-    ignored_layers: List[str] = [],
 ):
     named_modules = list(model.named_modules())
     for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
@@ -203,9 +234,11 @@ def quantize_weights(
             or name in quantize_config.ignored_layers
         ):
             continue
-        quant_weight, quant_scale = per_tensor_quantize(linear.weight)
+        quant_weight, weight_scale = per_tensor_quantize(linear.weight)
         bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
-        quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias)
+        quant_linear = FP8DynamicLinear(
+            qweight=quant_weight, weight_scale=weight_scale, bias=bias
+        )
         replace_module(model, name, quant_linear)
         del linear.weight
         del linear.bias
@@ -217,7 +250,6 @@ def quantize_activations(
     model: AutoModelForCausalLM,
     quantize_config: BaseQuantizeConfig,
     calibration_tokens,
-    ignored_layers: List[str] = [],
 ):
     # Replace weight quantizer with a dynamic activation quantizer observer
     for name, dynamic_quant_linear in model.named_modules():
@@ -227,16 +259,22 @@ def quantize_activations(
         ):
             continue
         quantizer = FP8StaticLinearQuantizer(
-            dynamic_quant_linear.weight,
-            dynamic_quant_linear.weight_scale,
-            dynamic_quant_linear.bias,
+            qweight=dynamic_quant_linear.qweight,
+            weight_scale=dynamic_quant_linear.weight_scale,
+            bias=dynamic_quant_linear.bias,
+            quantize_output=(
+                hasattr(quantize_config, "kv_cache_quant_layers")
+                and name in quantize_config.kv_cache_quant_layers
+            ),
         )
         replace_module(model, name, quantizer)
         del dynamic_quant_linear
     cleanup_memory()
 
     # Pass through calibration data to measure activation scales
-    with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
+    with tqdm.tqdm(
+        total=calibration_tokens.shape[0], desc="Calibrating activation scales"
+    ) as pbar:
         for row_idx in range(calibration_tokens.shape[0]):
             model(calibration_tokens[row_idx].reshape(1, -1))
             cleanup_memory()
@@ -250,10 +288,11 @@ def quantize_activations(
         ):
             continue
         static_proj = FP8StaticLinear(
-            quantizer.weight,
-            quantizer.weight_scale,
-            quantizer.bias,
-            quantizer.input_scale,
+            qweight=quantizer.qweight,
+            weight_scale=quantizer.weight_scale,
+            bias=quantizer.bias,
+            input_scale=quantizer.input_scale,
+            output_scale=quantizer.output_scale,
         )
         replace_module(model, name, static_proj)
         del quantizer
@@ -264,7 +303,6 @@ def save_quantized_model(
     model: AutoModelForCausalLM,
     quant_config: BaseQuantizeConfig,
     save_dir: str,
-    ignored_layers: List[str] = [],
 ):
     print(model)
     print(f"Saving the model to {save_dir}")
@@ -275,6 +313,8 @@ def save_quantized_model(
             "ignored_layers": quant_config.ignored_layers,
         }
     }
+    if hasattr(quant_config, "kv_cache_quant_layers"):
+        static_q_dict["quantization_config"]["kv_cache_scheme"] = "static"
     model.config.update(static_q_dict)
     model.save_pretrained(save_dir)
     tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
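Reviewer note: putting the updated entry points together, the calling convention after this change appears to be roughly the sketch below. The model name, calibration text, and the BaseQuantizeConfig constructor fields are assumptions inferred from how quantize_config is used in this diff (ignored_layers, activation_scheme, and the optional kv_cache_quant_layers attribute), and the helpers are assumed to be importable from this module:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"  # hypothetical small model for illustration
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

quant_config = BaseQuantizeConfig(
    activation_scheme="static",
    ignored_layers=["lm_head"],
    # Hypothetical entries; the diff checks full module names against this list,
    # and its presence triggers quantize_output plus kv_cache_scheme="static".
    kv_cache_quant_layers=["model.layers.0.self_attn.k_proj"],
)

calibration_tokens = tokenizer(
    ["The quick brown fox jumps over the lazy dog."] * 4,
    return_tensors="pt",
    padding=True,
).input_ids

quantize_weights(model, quant_config)
quantize_activations(model, quant_config, calibration_tokens)
save_quantized_model(model, quant_config, save_dir="opt-125m-fp8")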