@@ -742,7 +742,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             for a_or_w in ["num_bits_feature", "num_bits_weight"]
         ), "Please check nbits setting!"

-        target_device = kwargs.get(
+        tar_dev = kwargs.get(
             "target_device",
             kwargs.get("device", next(fms_mo_qlinear.parameters()).device),
         )
@@ -751,14 +751,15 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             fms_mo_qlinear.in_features,
             fms_mo_qlinear.out_features,
             bias=fms_mo_qlinear.bias is not None,
-            device=target_device,
+            device=tar_dev,
         )
         # Make sure to register an Op for integer matmul, could be real INT matmul or emulation
         qcfg = getattr(fms_mo_qlinear, "qcfg", {})
         qlin_int.use_int_kernel = kwargs.get(
             "use_int_kernel", qcfg.get("use_int_kernel", "cutlass")
         )
         qlin_int.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", False)
+        qlin_int.useDynMaxQfunc = kwargs.get("use_dynamic_max_act_Qfunc", False)
         qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlin_int.accminmax = (
             -(1 << (qlin_int.max_acc_bits - 1)),
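
As a usage sketch (the deploy class name QLinearINT8Deploy and the fms_mo_layer variable below are assumptions for illustration, not taken from this diff), the new kwarg is passed straight through from_fms_mo:

    # Hypothetical usage; QLinearINT8Deploy and fms_mo_layer are illustrative names only.
    qlin_int = QLinearINT8Deploy.from_fms_mo(
        fms_mo_layer,                  # an fms-mo quantized linear module
        target_device="cuda",
        use_int_kernel="cutlass",      # real INT matmul kernel or emulation
        use_PT_native_Qfunc=False,
        use_dynamic_max_act_Qfunc=-1,  # -1 / -2 selects the reduce dim (per-token / per-channel)
    )
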
@@ -773,34 +774,48 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         with torch.no_grad():
             Qa = fms_mo_qlinear.quantize_feature
             Qw = fms_mo_qlinear.quantize_weight
-            a_cv, a_cvn = Qa.clip_val.item(), Qa.clip_valn.item()
             w_cv = Qw.clip_val.item()
+            if qlin_int.useDynMaxQfunc in [-1, -2]:  # [-1, -2] indicates reduce_dim
+                # dynamic Qmax has no clip vals; register fake ones (never used in real calc)
+                Qa.register_buffer("clip_val", torch.tensor(8.0, device=tar_dev))
+                Qa.register_buffer("clip_valn", torch.tensor(-8.0, device=tar_dev))
+            a_cv, a_cvn = Qa.clip_val.item(), Qa.clip_valn.item()
+            # Store original cv_a and cv_w (as python floats, not tensors) and sq scales
+            # for later use (probably not necessary)
+            qlin_int.cvs = [a_cv, a_cvn, w_cv]
             # NOTE: Keep w transposed to prevent confusion
             Qw.dequantize = False
-            w_int8 = Qw(
-                fms_mo_qlinear.weight.float()
-            )  # Qw.clipval should have been updated after this
+            # trigger Qw.clipval re-calc for SAWB (if needed)
+            w_int8 = Qw(fms_mo_qlinear.weight.float())
             qlin_int.weight = nn.Parameter(
                 w_int8.to(torch.int8), requires_grad=False
             )  # NOTE: may need INT W stored as FP in some cases

-            if qlin_int.usePTnativeQfunc:
+            if qlin_int.useDynMaxQfunc in [-1, -2]:
+                input_scale = torch.tensor(1.0, device=tar_dev)
+                input_zero_point = torch.tensor(128, dtype=torch.int, device=tar_dev)
+                w_scale = torch.tensor(
+                    [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev
+                )
+            elif qlin_int.usePTnativeQfunc:
                 input_scale = torch.tensor(
-                    [(a_cv - a_cvn) / (2**qlin_int.nbits_a - 1)], device=target_device
+                    [(a_cv - a_cvn) / (2**qlin_int.nbits_a - 1)], device=tar_dev
                 )
                 input_zero_point = torch.round(-a_cvn / input_scale).to(torch.int)
-                w_scale = torch.tensor([w_cv * 2 / (2**qlin_int.nbits_w - 2)])
+                w_scale = torch.tensor(
+                    [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev
+                )
             else:
                 # fms_mo formula is a bit different from conventional PT formula
                 quant_scale = (2**qlin_int.nbits_a - 1) / torch.tensor(
-                    [a_cv - a_cvn], device=target_device
+                    [a_cv - a_cvn], device=tar_dev
                 )
                 quant_stepsize = 1.0 / quant_scale
                 quant_zero_point = torch.round(a_cvn * quant_scale)
                 input_scale = quant_stepsize
                 input_zero_point = -quant_zero_point
                 quant_w_scale = (2**qlin_int.nbits_a - 2) / torch.tensor(
-                    [w_cv * 2], device=target_device
+                    [w_cv * 2], device=tar_dev
                 )
                 w_scale = 1.0 / quant_w_scale
                 qlin_int.register_buffer("quant_scale", quant_scale)
@@ -812,9 +827,6 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             qlin_int.register_buffer("input_zp", input_zero_point)
             qlin_int.register_buffer("w_scale", w_scale)
             qlin_int.register_buffer("w_zp", w_zp)
-            # Store original cv_a and cv_w (in python floats, not tensors), and sq scales
-            # for later verification
-            qlin_int.cvs = [Qa.clip_val.item(), Qa.clip_valn.item(), Qw.clip_val.item()]

             corr_term = (
                 (input_zero_point - 128)
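
A quick arithmetic sketch of the two activation/weight scale conventions above; all clip values and bit widths below are made-up numbers, not taken from this diff:

    # Illustrative numbers only; nbits and clip values are assumptions for this sketch.
    a_cv, a_cvn, w_cv, nbits_a, nbits_w = 2.0, -2.0, 1.0, 8, 8

    # PT-native convention (usePTnativeQfunc=True):
    input_scale = (a_cv - a_cvn) / (2**nbits_a - 1)   # 4 / 255 ~= 0.0157
    input_zero_point = round(-a_cvn / input_scale)    # ~= 128 for this symmetric clip pair
    w_scale = w_cv * 2 / (2**nbits_w - 2)             # 2 / 254 ~= 0.0079 (symmetric, no -128 code)

    # Legacy fms-mo convention keeps the reciprocal around; the step size is identical:
    quant_scale = (2**nbits_a - 1) / (a_cv - a_cvn)   # ~= 63.75
    assert abs(1.0 / quant_scale - input_scale) < 1e-12
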
@@ -836,17 +848,14 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
                 qlin_int.register_buffer("bias", -corr_term.to(fms_mo_w_dtype))
                 qlin_int.org_model_has_bias = False

-            qlin_int.register_buffer("Qa_clip_val", Qa.clip_val.detach())
-            qlin_int.register_buffer(
-                "Qa_clip_valn", Qa.clip_valn.detach()
-            )  # TODO: case for PACT?
-            qlin_int.register_buffer(
-                "Qw_clip_val", Qw.clip_val.detach()
-            )  # asym W quantizer may have clipvaln
+            # redundant variables to be cleaned up
+            # qlin_int.register_buffer("Qa_clip_val", Qa.clip_val.detach())
+            # qlin_int.register_buffer("Qa_clip_valn", Qa.clip_valn.detach())
+            # qlin_int.register_buffer("Qw_clip_val", Qw.clip_val.detach())

         qlin_int.set_matmul_op()

-        return qlin_int.to(target_device)
+        return qlin_int.to(tar_dev)

     @classmethod
     def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):
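
The correction term folded into the bias buffer above compensates for storing uint8-style codes shifted by 128 as int8. A minimal numerical check of that identity (tensor shapes and the [in_features, out_features] weight layout are assumptions for this sketch):

    import torch

    s_a, zp = 0.02, 120                        # illustrative activation scale / uint8 zero point
    W_int = torch.randint(-127, 128, (8, 4))   # int8-valued weights
    x = torch.randn(3, 8)

    x_u8 = torch.clamp(torch.round(x / s_a) + zp, 0, 255)   # uint8-style codes
    x_i8 = x_u8 - 128                                       # what is actually stored as int8

    # (x_u8 - zp) @ W  ==  x_i8 @ W - (zp - 128) * W.sum(dim=0), i.e. the stored int8 codes plus
    # one constant per-output correction reproduce the zero-point-corrected product exactly.
    lhs = (x_u8 - zp) @ W_int.float()
    rhs = x_i8 @ W_int.float() - (zp - 128) * W_int.float().sum(dim=0)
    assert torch.allclose(lhs, rhs)
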
@@ -988,25 +997,15 @@ def qa_raw_qfunc(self, x):
         """
         Quantizes the input tensor x to 8-bit integer values using raw formula, slower if not
         torch.compiled
-
-        Args:
-            x (Tensor): Input tensor to be quantized.
-
-        Returns:
-            Tensor: Quantized tensor with values in the range [-128, 127].
         """
         x = torch.clamp((x / self.input_scale + self.input_zp - 128).round(), -128, 127)
         return x.to(torch.int8)

     def qa_fmo_mo_qfunc(self, x):
         """
-        Quantizes the input tensor x to 8-bit integer values.
-
-        Args:
-            x (Tensor): Input tensor to be quantized.
-
-        Returns:
-            Tensor: Quantized tensor with values in the range [-128, 127].
+        Quantizes the input tensor x to 8-bit integer values. Note that the old fms-mo formula
+        clamps before rounding, whereas the typical torch formula rounds before clamping.
+        (See qa_raw_qfunc() above.)
         """
         x = (
             torch.round(
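
To make qa_raw_qfunc concrete, here is a small trace; the scale and zero point are illustrative values, not taken from this diff:

    import torch

    input_scale, input_zp = torch.tensor(4 / 255), torch.tensor(128)
    x = torch.tensor([1.0, -2.5, 3.0])   # -2.5 and 3.0 sit outside the assumed [-2, 2] clip range
    q = torch.clamp((x / input_scale + input_zp - 128).round(), -128, 127)
    # q -> tensor([ 64., -128.,  127.]): in-range values land on the grid, outliers saturate
    q_int8 = q.to(torch.int8)
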
@@ -1017,6 +1016,21 @@ def qa_fmo_mo_qfunc(self, x):
         )
         return x.to(torch.int8)

+    def qa_dynamic_max_qfunc(self, x):
+        """
+        Symmetric dynamic quantizer, same as QDynMax, which allows per-token or per-channel.
+        This quantizer does not use a fixed self.input_scale; it recomputes it on every call.
+        NOTE
+        1. self.input_scale will have shape (x.shape[-2], 1) when reduce_dim == -1 and
+           (1, x.shape[-1]) when reduce_dim == -2 (keepdim=True).
+        2. input_scale must broadcast correctly together with w_scale (e.g. if per-Ch) at the
+           final output step, i.e. imm_out * (a_scale * w_scale) * ...
+        """
+        amax = x.abs().max(dim=self.useDynMaxQfunc, keepdim=True)[0]
+        levels = 2 ** (self.nbits_a - 1) - 1
+        self.input_scale = amax.clamp(min=1e-5).div(levels)
+        return torch.round(x / self.input_scale).to(torch.int8)
+
     def iaddmm_int(self, bias, m1, m2):
         """
         Performs integer matrix multiplication with optional addition of a bias term.
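
A quick shape check for the dynamic-max path above; tensor sizes are illustrative:

    import torch

    x = torch.randn(4, 16)       # e.g. [tokens, in_features]
    levels = 2 ** (8 - 1) - 1    # 127 levels for 8-bit symmetric quantization

    # reduce_dim == -1 -> one scale per token (row)
    scale_tok = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5) / levels   # shape [4, 1]
    q_tok = torch.round(x / scale_tok).to(torch.int8)

    # reduce_dim == -2 -> one scale per channel (column)
    scale_ch = x.abs().max(dim=-2, keepdim=True)[0].clamp(min=1e-5) / levels    # shape [1, 16]
    q_ch = torch.round(x / scale_ch).to(torch.int8)

    # keepdim=True keeps both scales broadcastable against x (and later against w_scale)
    assert (q_tok.float() * scale_tok).shape == x.shape
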
@@ -1034,7 +1048,9 @@ def iaddmm_int(self, bias, m1, m2):
             The result of the integer matrix multiplication with the bias added.
         """

-        if self.usePTnativeQfunc:
+        if self.useDynMaxQfunc in [-1, -2]:
+            m1 = self.qa_dynamic_max_qfunc(m1)
+        elif self.usePTnativeQfunc:
             m1 = self.qa_raw_qfunc(m1)
         else:
             m1 = self.qa_fmo_mo_qfunc(m1)
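
For orientation, the accminmax bounds registered during conversion cap an emulated integer accumulator. A rough sketch of that idea follows; it is not the module's actual matmul op (which is installed by set_matmul_op()), and all names and sizes here are assumptions:

    import torch

    max_acc_bits = 32
    accmin, accmax = -(1 << (max_acc_bits - 1)), (1 << (max_acc_bits - 1)) - 1

    m1_int8 = torch.randint(-128, 128, (4, 64)).double()   # int8-valued activations
    w_int8 = torch.randint(-128, 128, (64, 16)).double()   # int8-valued weights
    # fp64 keeps int8 dot products exact; the clamp emulates a max_acc_bits-wide accumulator
    acc = (m1_int8 @ w_int8).clamp(accmin, accmax)
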