@@ -3,7 +3,7 @@
 
 import torch
 import torch.nn.functional as F
-from bitorch_engine.utils.quant_operators import nv_tensor_quant, gptq_stype_unpacking
+from bitorch_engine.utils.quant_operators import nv_tensor_quant, gptq_style_unpacking, gptq_style_zeros_packing
 from bitorch_engine.functions.cuda import tensor_to_packed_uint8, unpack_uint8_tensor
 
 
@@ -327,6 +327,39 @@ def init_weight(weight: torch.Tensor, cls: Type[torch.nn.Parameter]=torch.nn.Par
     return weight, scale_w
 
 
+def update_zeros(qweight, w, norm_grad, step_size, z_unpacked=None):
+    """
+    Updates the zeros attribute of the qweight object based on its layer type.
+
+    Args:
+        qweight: An object containing quantization parameters, including the zeros attribute.
+        w: Weight tensor.
+        norm_grad: Normalized gradient tensor.
+        step_size: Step size for updating zeros.
+        z_unpacked: Optional unpacked zeros tensor for specific layer types.
+    """
+    if qweight.layer_type == 2:  # MBWQ-layer
+        q_perm = qweight.q_perm.unsqueeze(1).repeat(1, w.size(1)).long()
+        zeros_grad = torch.gather(norm_grad, dim=0, index=q_perm)
+        qweight.zeros.add_(
+            step_size * zeros_grad.view(-1, w.size(0) // qweight.scales.size(0), qweight.scales.size(-1)).mean(1)
+        )
+        del zeros_grad
+    elif qweight.layer_type == 1 and qweight.g_idx is not None:  # MPQ-layer & GPTQ
+        zeros_unpack = z_unpacked[qweight.g_idx.long()]
+        zeros_unpack.add_(step_size * norm_grad)
+
+        g_idx = qweight.g_idx.long()
+        perm = torch.argsort(g_idx, dim=0)
+        zeros = zeros_unpack[perm, :].view(-1, w.size(0) // qweight.scales.size(0), qweight.scales.size(-1)).mean(1)
+
+        # pack to qzeros
+        qweight.zeros = gptq_style_zeros_packing(zeros, qweight.w_bit, zeros.size(-1), qweight.group_size)
+    else:
+        raise NotImplementedError(
+            "qweight.layer_type: '{}' is not supported yet.".format(str(qweight.layer_type)))
+
+
 def qweight_update_fn(qweight: torch.nn.Parameter, exp_avg_s: torch.Tensor = None, exp_avg_l: torch.Tensor = None,
                       step: torch.Tensor = None, lr: float = 1e-4, weight_decay: float = 0.0, beta1: float = 0.99,
                       beta2: float = 0.9999, eps: float = 1e-6, dtype=torch.half, correct_bias=None, projector=None,
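
The MPQ/GPTQ branch of `update_zeros` above does a fair amount of index shuffling: zero points are broadcast from one row per group to one row per weight row via `g_idx`, nudged by the gradient, restored to group order, and averaged back down to one row per group before repacking. Below is a minimal, self-contained sketch of that reshaping in plain PyTorch with toy shapes; the names and dimensions are illustrative assumptions, and the final `gptq_style_zeros_packing` call is omitted.

```python
import torch

# Toy shapes, assumed for illustration: w is (in_features, out_features),
# with one zero-point row per group of input rows.
in_features, out_features = 8, 4
num_groups, group_size = 2, 4
step_size = 1e-3

norm_grad = torch.randn(in_features, out_features)   # stand-in for norm_grad
zeros = torch.zeros(num_groups, out_features)        # stand-in for unpacked zeros
g_idx = torch.arange(in_features) // group_size      # row -> group map (GPTQ g_idx)

# broadcast each group's zeros to its rows, as z_unpacked[qweight.g_idx.long()] does
zeros_unpack = zeros[g_idx]
zeros_unpack.add_(step_size * norm_grad)

# restore group order, then average each group's rows back to a single row
perm = torch.argsort(g_idx, dim=0)
new_zeros = zeros_unpack[perm, :].view(-1, in_features // num_groups, out_features).mean(1)
assert new_zeros.shape == (num_groups, out_features)  # ready for gptq_style_zeros_packing
```
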
@@ -452,7 +485,9 @@ def qweight_update_fn(qweight: torch.nn.Parameter, exp_avg_s: torch.Tensor=None,
     elif isinstance(qweight, MPQWeightParameter):
 
         # unpack qweight
-        w = gptq_stype_unpacking(qweight).to(dtype)
+        w, z_unpacked = gptq_style_unpacking(qweight)
+        w = w.to(dtype)
+        z_unpacked = z_unpacked.to(dtype)
 
         # Decay the first and second moment running average coefficient
         # In-place operations to update the averages at the same time
@@ -475,11 +510,19 @@ def qweight_update_fn(qweight: torch.nn.Parameter, exp_avg_s: torch.Tensor=None,
 
         w.add_(norm_grad, alpha=-step_size)
 
-        if weight_decay > 0.0:
-            w.add_(w, alpha=(-lr * weight_decay))
+        # ===== update zeros ===== #
+        # We do not apply a conventional gradient update to 'zeros'. Although 'zeros'
+        # is a floating-point tensor, in our optimization scenario it is tied to the
+        # updates of 'qweight', and it interacts with 'qweight' at a relatively sparse
+        # frequency rather than being updated at every step. Treating 'zeros' as a
+        # regular fp-parameter would not leave us the flexibility to design these
+        # interactions conveniently. As this is a beta version, the behavior may
+        # change in future updates.
+        if step % 5 == 0:
+            update_zeros(qweight, w, norm_grad, step_size, z_unpacked)
 
         # pack fp weight back to Q-weight and update qweight data
-        qweight.data = pack_fp_weight(w, qweight)
+        qweight.data = pack_fp_weight(w, qweight, z_unpacked)
 
         # manually empty cuda cache.
         del w
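
For the MBWQ branch (`layer_type == 2`), which `update_zeros` takes when the sparse `step % 5 == 0` update fires on such layers, the zeros gradient is instead gathered through the layer's row permutation `q_perm` before the same group-averaging step. A companion sketch, again with toy shapes and illustrative stand-in names:

```python
import torch

# Toy shapes, assumed for illustration.
in_features, out_features = 8, 4
num_groups = 2
step_size = 1e-3

norm_grad = torch.randn(in_features, out_features)
q_perm = torch.randperm(in_features)   # stand-in for qweight.q_perm

# expand the row permutation across columns, then gather the permuted gradients
index = q_perm.unsqueeze(1).repeat(1, out_features).long()
zeros_grad = torch.gather(norm_grad, dim=0, index=index)

# collapse each group's rows into a single zero-point update, as in the diff
delta = step_size * zeros_grad.view(-1, in_features // num_groups, out_features).mean(1)
assert delta.shape == (num_groups, out_features)  # added in-place to qweight.zeros
```
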