@@ -123,12 +123,24 @@ def get_activation_quantizer(
         )
     elif qa_mode == "dorefa":
         act_quantizer = dorefa_quantize_activation
-    elif (
-        qa_mode == "max"
-    ):  # NOTE Need to be careful using this for activation, particular to 1 sided.
-        act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
-    elif qa_mode == "minmax":
-        act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+
+    elif "max" in qa_mode:
+        # NOTE Be careful using this for activations, particularly one-sided ones.
+        if "min" in qa_mode:
+            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+        elif "pertoken" in qa_mode or "perToken" in qa_mode:
+            act_quantizer = QMaxDynamic(nbits, dim=-1)
+        elif "per_channel" in qa_mode or "perCh" in qa_mode:
+            act_quantizer = QMaxDynamic(nbits, dim=-2)
+        elif "sym" in qa_mode:
+            act_quantizer = Qmax(
+                nbits,
+                align_zero=True,
+                minmax=False,
+                extend_act_range=extend_act_range,
+            )
+        else:
+            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
     elif qa_mode == "fix":
         act_quantizer = QFixSymmetric(
             nbits, init_clip_val=clip_val, align_zero=align_zero
@@ -140,13 +152,7 @@ def get_activation_quantizer(
             minmax=False,
             extend_act_range=extend_act_range,
         )
-    elif qa_mode == "pactsym":
-        act_quantizer = PACT2Sym(
-            nbits,
-            init_clip_val=clip_val,
-            dequantize=True,
-            inplace=False,
-        )
+
     elif qa_mode == "pactsym+":
         act_quantizer = PACTplusSym(
             nbits,
@@ -179,8 +185,6 @@ def get_activation_quantizer(
             perToken=perToken,
             emulate=True,
         )
-    elif qa_mode == "pertokenmax":
-        act_quantizer = PerTokenMax(nbits)
     else:
         raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
 else:  # swcap-compatible activation quantizers
@@ -3488,6 +3492,42 @@ def __repr__(self):
         return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"


+class QMaxDynamic(nn.Module):
+    def __init__(self, num_bits, dim=-1):
+        """
+        Per-token or per-channel quantization using abs().max() as the scale, usually for
+        activations; can also be used for Qbmm M2.
+        (reduce) dim = -1 -> reduce over the last dim (per-token for 2D input)
+                 dim = -2 -> per-channel
+        Zero is aligned so that the levels are symmetric around zero (losing one level).
+        Since the token length is unknown before runtime, the quantizer can only compute the
+        scales dynamically at run time, i.e. no trainable quantization scale is allowed
+        (unless the input seq length is always the same, not just padded to a fixed length).
+        """
+        super().__init__()
+        self.num_bits = num_bits
+        self.levels = 2 ** (self.num_bits - 1) - 1
+        if isinstance(dim, str):
+            if "perCh" in dim or "per_channel" in dim:
+                dim = -2
+            elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
+                dim = -1
+        if dim in [-1, -2]:
+            self.reduce_dim = dim
+        else:
+            raise ValueError(
+                f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
+            )
+
+    def forward(self, input_tensor):
+        amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
+        scales = amax_dim.clamp(min=1e-5).div(self.levels)
+        return input_tensor.div(scales).round().mul(scales)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_bits={self.num_bits}, dim={self.reduce_dim})"
+
+
 class Qdynamic(nn.Module):
     def __init__(
         self,
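For reference, a minimal usage sketch of the new `QMaxDynamic` quantizer (not part of the diff; it assumes the class above is in scope, and the tensor shapes are purely illustrative):

```python
import torch

# Per-token: reduce abs().max() over the last dim -> one scale per token.
quant_per_token = QMaxDynamic(num_bits=8, dim=-1)
# Per-channel: the string alias maps to dim=-2 -> one scale per channel.
quant_per_channel = QMaxDynamic(num_bits=8, dim="perCh")

x = torch.randn(4, 16, 64)  # hypothetical (batch, tokens, channels) activation

xq_tok = quant_per_token(x)   # scales computed at runtime with shape (4, 16, 1)
xq_ch = quant_per_channel(x)  # scales computed at runtime with shape (4, 1, 64)

# Fake-quantized outputs keep the input shape; values are rounded onto
# 2**(num_bits - 1) - 1 symmetric levels within each scale group.
assert xq_tok.shape == x.shape and xq_ch.shape == x.shape
```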