import gc
import re
- from typing import List, Tuple
+ from typing import Optional, Tuple
import copy

import torch
@@ -61,14 +61,22 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
    return qweight, scale


+ def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
+     finfo = torch.finfo(torch.float8_e4m3fn)
+     qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
+     return qweight.to(torch.float8_e4m3fn)
+
+
def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
    if A.numel() == 0:
        # Deal with empty tensors (triggered by empty MoE experts)
        return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
-
-     native_fp8_support = (
-         torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-     )
+
+     # TODO: Disable native fp8 gemm for now, always just dequantize
+     # native_fp8_support = (
+     #     torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+     # )
+     native_fp8_support = False
    if native_fp8_support:
        need_reshape = A.dim() == 3
        if need_reshape:
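For context, a minimal sketch of how the new static helper pairs with the dynamic path: dynamic quantization derives the scale from the tensor itself, while static quantization reuses a scale recorded earlier. The amax-based scale formula below is an assumption for illustration (requires PyTorch with float8_e4m3fn support, >= 2.1), not necessarily the exact formula used by per_tensor_quantize in this file.

import torch

def dynamic_quantize_sketch(t: torch.Tensor):
    # Assumed dynamic scale: map the largest magnitude to the fp8 max value.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = t.abs().max().clamp(min=1e-12) / finfo.max
    q = (t / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    return q, scale

x = torch.randn(4, 8)
q_dyn, scale = dynamic_quantize_sketch(x)   # dynamic: scale computed from x per call
q_static = (x / scale).clamp(               # static: reuse the recorded scale
    min=torch.finfo(torch.float8_e4m3fn).min,
    max=torch.finfo(torch.float8_e4m3fn).max,
).to(torch.float8_e4m3fn)
print((q_static.to(torch.float32) * scale - x).abs().max())  # round-trip error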
@@ -98,25 +106,24 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
    return output


- class FP8StaticLinearQuantizer(torch.nn.Module):
+ # Class responsible for quantizing weights
+ class FP8DynamicLinear(torch.nn.Module):
    def __init__(
-         self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor
+         self,
+         weight: torch.Tensor,
+         weight_scale: torch.Tensor,
+         bias: torch.nn.Parameter,
    ):
        super().__init__()
-         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
+         self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-         self.input_scale = None
        self.bias = bias

    def forward(self, x):
-         qinput, x_input_scale = per_tensor_quantize(x)
-         if self.input_scale is None:
-             self.input_scale = torch.nn.Parameter(x_input_scale)
-         elif x_input_scale > self.input_scale:
-             self.input_scale = torch.nn.Parameter(x_input_scale)
+         qinput, x_scale = per_tensor_quantize(x)
        output = fp8_gemm(
            A=qinput,
-             A_scale=self.input_scale,
+             A_scale=x_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=self.bias,
@@ -125,29 +132,29 @@ def forward(self, x):
        return output


- class FP8StaticLinear(torch.nn.Module):
+ # Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer
+ class FP8StaticLinearQuantizer(torch.nn.Module):
    def __init__(
        self,
-         qweight: torch.Tensor,
+         weight: torch.Tensor,
        weight_scale: torch.Tensor,
-         bias: torch.Tensor,
-         input_scale: float = 1.0,
+         bias: torch.nn.Parameter,
+         quantize_output: bool = False,
    ):
        super().__init__()
-         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
+         self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-         self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
        self.bias = bias
-
-     def per_tensor_quantize(
-         self, tensor: torch.Tensor, inv_scale: float
-     ) -> torch.Tensor:
-         finfo = torch.finfo(torch.float8_e4m3fn)
-         qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
-         return qweight.to(torch.float8_e4m3fn)
+         self.input_scale = None
+         self.output_scale = None
+         self.quantize_output = quantize_output

    def forward(self, x):
-         qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
+         qinput, x_input_scale = per_tensor_quantize(x)
+         if self.input_scale is None:
+             self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
+         elif x_input_scale > self.input_scale:
+             self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
        output = fp8_gemm(
            A=qinput,
            A_scale=self.input_scale,
@@ -156,26 +163,51 @@ def forward(self, x):
            bias=self.bias,
            out_dtype=x.dtype,
        )
+
+         # Optionally, quantize output and record scale
+         if self.quantize_output:
+             qoutput, output_scale = per_tensor_quantize(output)
+             if self.output_scale is None:
+                 self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
+             elif output_scale > self.output_scale:
+                 self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
+             output = qoutput.to(output.dtype) * output_scale
+
        return output


- class FP8DynamicLinear(torch.nn.Module):
-     def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor):
+ # Module responsible for representing the final checkpoint representation
+ class FP8StaticLinear(torch.nn.Module):
+     def __init__(
+         self,
+         weight: torch.nn.Parameter,
+         weight_scale: torch.nn.Parameter,
+         bias: torch.nn.Parameter,
+         input_scale: torch.nn.Parameter,
+         output_scale: Optional[torch.nn.Parameter] = None,
+     ):
        super().__init__()
-         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
-         self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
+         self.weight = weight
+         self.weight_scale = weight_scale
        self.bias = bias
+         self.input_scale = input_scale
+         self.output_scale = output_scale

    def forward(self, x):
-         qinput, x_scale = per_tensor_quantize(x)
+         qinput = static_per_tensor_quantize(x, self.input_scale)
        output = fp8_gemm(
            A=qinput,
-             A_scale=x_scale,
+             A_scale=self.input_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=self.bias,
            out_dtype=x.dtype,
        )
+
+         if self.output_scale:
+             qoutput = static_per_tensor_quantize(output, self.output_scale)
+             output = qoutput.to(output.dtype) * self.output_scale
+
        return output

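To make the calibration flow these three modules implement concrete (dynamic quantization, then an observer pass over calibration data, then frozen static scales), here is a minimal sketch of the running-max scale update that FP8StaticLinearQuantizer.forward performs across calibration batches; the amax-based per-tensor scale is an assumed formula for illustration.

import torch

finfo = torch.finfo(torch.float8_e4m3fn)
input_scale = None
for step in range(4):                        # pretend calibration batches
    x = torch.randn(2, 16) * (step + 1)      # activation range grows over batches
    batch_scale = x.abs().max() / finfo.max  # assumed per-tensor scale, for illustration
    # Same running-max rule as the observer: keep the largest scale seen so far.
    if input_scale is None or batch_scale > input_scale:
        input_scale = batch_scale
print(input_scale)  # the value FP8StaticLinear would then reuse at inference time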
@@ -194,7 +226,6 @@ def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.
def quantize_weights(
    model: AutoModelForCausalLM,
    quantize_config: BaseQuantizeConfig,
-     ignored_layers: List[str] = [],
):
    named_modules = list(model.named_modules())
    for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
@@ -203,9 +234,11 @@ def quantize_weights(
            or name in quantize_config.ignored_layers
        ):
            continue
-         quant_weight, quant_scale = per_tensor_quantize(linear.weight)
+         quant_weight, weight_scale = per_tensor_quantize(linear.weight)
        bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
-         quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias)
+         quant_linear = FP8DynamicLinear(
+             weight=quant_weight, weight_scale=weight_scale, bias=bias
+         )
        replace_module(model, name, quant_linear)
        del linear.weight
        del linear.bias
@@ -217,7 +250,6 @@ def quantize_activations(
    model: AutoModelForCausalLM,
    quantize_config: BaseQuantizeConfig,
    calibration_tokens,
-     ignored_layers: List[str] = [],
):
    # Replace weight quantizer with a dynamic activation quantizer observer
    for name, dynamic_quant_linear in model.named_modules():
@@ -227,9 +259,13 @@ def quantize_activations(
        ):
            continue
        quantizer = FP8StaticLinearQuantizer(
-             dynamic_quant_linear.weight,
-             dynamic_quant_linear.weight_scale,
-             dynamic_quant_linear.bias,
+             weight=dynamic_quant_linear.weight,
+             weight_scale=dynamic_quant_linear.weight_scale,
+             bias=dynamic_quant_linear.bias,
+             quantize_output=(
+                 hasattr(quantize_config, "kv_cache_quant_layers")
+                 and name in quantize_config.kv_cache_quant_layers
+             ),
        )
        replace_module(model, name, quantizer)
        del dynamic_quant_linear
@@ -251,21 +287,45 @@ def quantize_activations(
        ):
            continue
        static_proj = FP8StaticLinear(
-             quantizer.weight,
-             quantizer.weight_scale,
-             quantizer.bias,
-             quantizer.input_scale,
+             weight=quantizer.weight,
+             weight_scale=quantizer.weight_scale,
+             bias=quantizer.bias,
+             input_scale=quantizer.input_scale,
+             output_scale=quantizer.output_scale,
        )
        replace_module(model, name, static_proj)
        del quantizer
    cleanup_memory()

+     # Post-process step for kv cache scales: take the k/v module `output_scale`
+     # parameters, take the max of them, and store it in the parent attention
+     # module as `kv_scale`
+     # NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block
+     if hasattr(quantize_config, "kv_cache_quant_layers"):
+         # Assumes that the list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...]
+         # so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
+         kv_proj_pairs = zip(*[iter(quantize_config.kv_cache_quant_layers)] * 2)
+         for k_proj_name, v_proj_name in kv_proj_pairs:
+             parent_module_name = ".".join(k_proj_name.split(".")[:-1])
+             assert parent_module_name == ".".join(v_proj_name.split(".")[:-1])
+             parent_module = dict(model.named_modules())[parent_module_name]
+
+             k_proj = dict(model.named_modules())[k_proj_name]
+             v_proj = dict(model.named_modules())[v_proj_name]
+
+             kv_scale = max(k_proj.output_scale, v_proj.output_scale)
+             parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False)
+
+             # Remove output_scale from k_proj and v_proj
+             k_proj.output_scale = None
+             v_proj.output_scale = None
+     cleanup_memory()
+


def save_quantized_model(
    model: AutoModelForCausalLM,
    quant_config: BaseQuantizeConfig,
    save_dir: str,
-     ignored_layers: List[str] = [],
):
    print(model)
    print(f"Saving the model to {save_dir}")
@@ -276,6 +336,8 @@ def save_quantized_model(
            "ignored_layers": quant_config.ignored_layers,
        }
    }
+     if hasattr(quant_config, "kv_cache_quant_layers"):
+         static_q_dict["quantization_config"]["kv_cache_scheme"] = "static"
    model.config.update(static_q_dict)
    model.save_pretrained(save_dir)
    tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
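Taken together, the passes in this file could be driven roughly as sketched below. The quantize_weights, quantize_activations, and save_quantized_model signatures match the diff, but the import path, the BaseQuantizeConfig fields, the model id, and the calibration text are assumptions for illustration, not a confirmed API.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical import path for the functions defined in this file.
from quantize import (
    BaseQuantizeConfig,
    quantize_weights,
    quantize_activations,
    save_quantized_model,
)

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Assumed config fields, mirroring the attributes this file reads
# (ignored_layers and, optionally, kv_cache_quant_layers).
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignored_layers=["lm_head"],
)
# Optionally set quantize_config.kv_cache_quant_layers to a flat, ordered
# [k_proj, v_proj, ...] name list to also record kv cache scales.

calibration_tokens = tokenizer("example calibration text", return_tensors="pt").input_ids
quantize_weights(model, quantize_config)
quantize_activations(model, quantize_config, calibration_tokens)
save_quantized_model(model, quantize_config, save_dir="model-fp8-static")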