@@ -123,23 +123,28 @@ def get_activation_quantizer(
             )
         elif qa_mode == "dorefa":
             act_quantizer = dorefa_quantize_activation
-        elif (
-            qa_mode == "max"
-        ):  # NOTE Need to be careful using this for activation, particular to 1 sided.
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
-        elif qa_mode == "minmax":
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+
+        elif "max" in qa_mode:
+            # NOTE Need to be careful using this for activation, particularly for 1-sided.
+            if "min" in qa_mode:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+            elif "pertoken" in qa_mode or "perToken" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-1)
+            elif "per_channel" in qa_mode or "perCh" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-2)
+            elif "sym" in qa_mode:
+                act_quantizer = Qmax(
+                    nbits,
+                    align_zero=True,
+                    minmax=False,
+                    extend_act_range=extend_act_range,
+                )
+            else:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
         elif qa_mode == "fix":
             act_quantizer = QFixSymmetric(
                 nbits, init_clip_val=clip_val, align_zero=align_zero
             )
-        elif qa_mode == "maxsym":
-            act_quantizer = Qmax(
-                nbits,
-                align_zero=True,
-                minmax=False,
-                extend_act_range=extend_act_range,
-            )
         elif qa_mode == "pactsym":
             act_quantizer = PACT2Sym(
                 nbits,
@@ -179,8 +184,6 @@ def get_activation_quantizer(
                 perToken=perToken,
                 emulate=True,
             )
-        elif qa_mode == "pertokenmax":
-            act_quantizer = PerTokenMax(nbits)
         else:
             raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
     else:  # swcap-compatible activation quantizers
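For reference, a rough sketch of how the reworked substring dispatch above is expected to resolve a few mode strings. This is illustrative only; the keyword-style call and the specific qa_mode strings are assumptions based on this diff, not taken from the repo:

# Assumed usage of the factory above; qa_mode strings are examples, not an exhaustive list.
get_activation_quantizer(qa_mode="max", nbits=8)          # -> Qmax(minmax=False), as before
get_activation_quantizer(qa_mode="minmax", nbits=8)       # -> Qmax(minmax=True), as before
get_activation_quantizer(qa_mode="maxsym", nbits=8)       # -> symmetric Qmax with extend_act_range
get_activation_quantizer(qa_mode="maxpertoken", nbits=8)  # -> QMaxDynamic(dim=-1), dynamic per-token
get_activation_quantizer(qa_mode="maxperCh", nbits=8)     # -> QMaxDynamic(dim=-2), dynamic per-channel
# Note: "pertokenmax" also contains "max" and "pertoken", so it now resolves to
# QMaxDynamic(dim=-1) instead of the removed PerTokenMax branch.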
@@ -3491,6 +3494,42 @@ def __repr__(self):
         return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"


+class QMaxDynamic(nn.Module):
+    def __init__(self, num_bits, dim=-1):
+        """
+        Per-token or per-channel quantizer that uses abs().max() as the scale, usually for
+        activations and could be used for Qbmm M2 as well.
+        (reduce) dim = -1 -> abs().max() yields a column vector (for 2D input) => per-token
+                 dim = -2 -> per-channel
+        Zero is aligned so that the levels are symmetric around zero (losing one level).
+        Since the token length is unknown before runtime, the quantizer can only compute the
+        scales dynamically at run time, meaning no trainable quantization scale is allowed
+        (unless the input seq length is always the same, not just padded to a fixed length).
+        """
+        super().__init__()
+        self.num_bits = num_bits
+        self.levels = 2 ** (self.num_bits - 1) - 1
+        if isinstance(dim, str):
+            if "perCh" in dim or "per_channel" in dim:
+                dim = -2
+            elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
+                dim = -1
+        if dim in [-1, -2]:
+            self.reduce_dim = dim
+        else:
+            raise ValueError(
+                f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
+            )
+
+    def forward(self, input_tensor):
+        amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
+        scales = amax_dim.clamp(min=1e-5).div(self.levels)
+        return input_tensor.div(scales).round().mul(scales)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
+
+
 class Qdynamic(nn.Module):
     def __init__(
         self,
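And a minimal, self-contained sketch of the arithmetic QMaxDynamic.forward performs, assuming PyTorch; the tensor values and the 8-bit setting are made up for illustration:

import torch

x = torch.tensor([[0.05, -1.20, 0.70],
                  [2.40, -0.30, 0.10]])   # (tokens, channels), illustrative activations
levels = 2 ** (8 - 1) - 1                 # 127 levels for 8 bits, symmetric around zero
# dim=-1 -> one scale per token (row); dim=-2 would give one scale per channel (column)
scales = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5) / levels
x_fq = x.div(scales).round().mul(scales)  # fake-quantized tensor, same shape as x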