@@ -152,9 +152,9 @@ def __init__(
     def forward(self, x):
         qinput, x_input_scale = per_tensor_quantize(x)
         if self.input_scale is None:
-            self.input_scale = torch.nn.Parameter(x_input_scale)
+            self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
         elif x_input_scale > self.input_scale:
-            self.input_scale = torch.nn.Parameter(x_input_scale)
+            self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
@@ -168,9 +168,9 @@ def forward(self, x):
         if self.quantize_output:
             qoutput, output_scale = per_tensor_quantize(output)
             if self.output_scale is None:
-                self.output_scale = torch.nn.Parameter(output_scale)
+                self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
             elif output_scale > self.output_scale:
-                self.output_scale = torch.nn.Parameter(output_scale)
+                self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
             output = qoutput.to(output.dtype) * output_scale

         return output
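
For context, the scales recorded above come from `per_tensor_quantize`, whose definition is not part of this diff. A minimal sketch of such a helper, assuming an FP8 E4M3 target and using a hypothetical name, could look like:

import torch

def per_tensor_quantize_sketch(tensor: torch.Tensor):
    # Map the tensor's absolute max onto the FP8 E4M3 representable range and
    # return the quantized tensor together with its single per-tensor scale.
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = tensor.abs().max().clamp(min=1e-12)
    scale = amax.float() / finfo.max
    qtensor = (tensor / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    return qtensor, scale

The diff itself only changes how the resulting scales are stored: wrapping them in `torch.nn.Parameter(..., requires_grad=False)` registers them on the module without marking them as trainable.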
@@ -307,6 +307,30 @@ def quantize_activations(
         del quantizer
         cleanup_memory()

+    # Post-process step for kv cache scales to take the k/v module
+    # `output_scale` parameters, take the max of them, and store them in
+    # the parent attention module as `kv_scale`
+    # NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block
+    if hasattr(quantize_config, "kv_cache_quant_layers"):
+        # Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...]
+        # so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
+        kv_proj_pairs = zip(*[iter(quantize_config.kv_cache_quant_layers)] * 2)
+        for k_proj_name, v_proj_name in kv_proj_pairs:
+            parent_module_name = ".".join(k_proj_name.split(".")[:-1])
+            assert parent_module_name == ".".join(v_proj_name.split(".")[:-1])
+            parent_module = dict(model.named_modules())[parent_module_name]
+
+            k_proj = dict(model.named_modules())[k_proj_name]
+            v_proj = dict(model.named_modules())[v_proj_name]
+
+            kv_scale = max(k_proj.output_scale, v_proj.output_scale)
+            parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False)
+
+            # Remove output_scale from k_proj and v_proj
+            k_proj.output_scale = None
+            v_proj.output_scale = None
+    cleanup_memory()
+

 def save_quantized_model(
     model: AutoModelForCausalLM,
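
The `zip(*[iter(...)] * 2)` expression in the added block pairs consecutive entries of the flat, ordered layer list. A small illustration with hypothetical layer names:

# The same iterator object is consumed twice per zip step, yielding
# (k_proj, v_proj) tuples from the flat list.
layers = [
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.1.self_attn.k_proj",
    "model.layers.1.self_attn.v_proj",
]
for k_name, v_name in zip(*[iter(layers)] * 2):
    parent = ".".join(k_name.split(".")[:-1])
    print(parent, "->", k_name, v_name)
    # model.layers.0.self_attn -> model.layers.0.self_attn.k_proj model.layers.0.self_attn.v_proj
    # model.layers.1.self_attn -> model.layers.1.self_attn.k_proj model.layers.1.self_attn.v_proj

For each pair, the post-process step stores the larger of the two `output_scale` values on the parent attention module, so a single `kv_scale` covers both the key and value halves of the kv cache.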