@@ -795,20 +795,23 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         qlin_int.cvs = [a_cv, a_cvn, w_cv]  # TODO remove the need of this?
 
         # prepare smoothQuant scale, = (smQ_a_scale ^ alpha)/(smQ_w_scale ^ (1-alpha))
-        smq_scale = torch.tensor([1.0], device=tar_dev, dtype=fms_mo_w_dtype)
+        smoothq_scale = torch.tensor([1.0], device=tar_dev, dtype=fms_mo_w_dtype)
         if getattr(fms_mo_qlinear, "smoothq", False):
-            smq_a_scale = fms_mo_qlinear.smoothq_act_scale
-            smq_w_scale = (
+            smoothq_a_scale = fms_mo_qlinear.smoothq_act_scale
+            smoothq_w_scale = (
                 fms_mo_qlinear.weight.abs()
                 .max(dim=0, keepdim=True)[0]
                 .clamp(min=1e-5)
             )
-            smq_alpha = fms_mo_qlinear.smoothq_alpha
-            if torch.all(smq_a_scale != 0).item():
-                smq_scale = (
-                    (smq_a_scale ** smq_alpha / smq_w_scale ** (1.0 - smq_alpha))
+            smoothq_alpha = fms_mo_qlinear.smoothq_alpha
+            if torch.all(smoothq_a_scale != 0).item():
+                smoothq_scale = (
+                    (
+                        smoothq_a_scale ** smoothq_alpha
+                        / smoothq_w_scale ** (1.0 - smoothq_alpha)
+                    )
                     .clamp(min=1e-5)
-                    .to(smq_a_scale.dtype)
+                    .to(smoothq_a_scale.dtype)
                 )
 
         # could trigger Qw.clipval re-calc for SAWB here, if needed
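
For readers unfamiliar with SmoothQuant, the hunk above computes a per-input-channel factor s = act_scale^alpha / w_scale^(1-alpha). Below is a minimal standalone sketch of the same computation; the names `act_scale` and `alpha` and the tensor shapes are illustrative assumptions, not taken from this commit:

import torch

weight = torch.randn(256, 128)          # [out_features, in_features]
act_scale = torch.rand(1, 128) * 4.0    # hypothetical calibrated per-channel activation max
alpha = 0.5                             # migration strength in [0, 1], as in the SmoothQuant paper

# per-input-channel weight magnitude, clamped to avoid division by zero
w_scale = weight.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)
smoothq_scale = (act_scale ** alpha / w_scale ** (1.0 - alpha)).clamp(min=1e-5)

# the factor is later folded into W and divided out of x, so
# (x / smoothq_scale) @ (weight * smoothq_scale).t() == x @ weight.t()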
@@ -845,18 +848,18 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         qlin_int.register_buffer("input_zp", input_zero_point)
         qlin_int.register_buffer("w_scale", w_scale)
         qlin_int.register_buffer("w_zp", w_zp)
-        qlin_int.register_buffer("smq_scale", smq_scale)
+        qlin_int.register_buffer("smoothq_scale", smoothq_scale)
 
         # NOTE:
         # 1. Keep W transposed to prevent confusion, hence (W.t()/scale).t()
         # 2. only a few quantizer have .dequantize working correctly, e.g. SAWB
         # 3. smooth_quant factor is included in the W here, will also include it in the forward
         if isinstance(Qw, SAWB):
             Qw.dequantize = False
-            w_int8 = Qw(fms_mo_qlinear.weight.float() * smq_scale)
+            w_int8 = Qw(fms_mo_qlinear.weight.float() * smoothq_scale)
         else:
             w_int8 = (
-                torch.round((fms_mo_qlinear.weight * smq_scale).t() / w_scale)
+                torch.round((fms_mo_qlinear.weight * smoothq_scale).t() / w_scale)
                 .clamp(-w_levels / 2, w_levels / 2)
                 .t()
             )
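
A hedged sketch of what the non-SAWB branch above does: fold the smoothing factor into W, then quantize symmetrically per output channel. The shapes and the way `w_scale` is derived here are assumptions for illustration; in the module, `w_scale` comes from the quantizer earlier in `from_fms_mo`:

import torch

weight = torch.randn(256, 128)             # [out, in]
smoothq_scale = torch.rand(1, 128) + 0.5   # per-input-channel factor from the previous step
w_levels = 254                             # symmetric int8 grid: [-127, 127]

w_smoothed = weight * smoothq_scale
# assumed per-output-channel scale, shape [out]
w_scale = w_smoothed.abs().max(dim=1)[0].clamp(min=1e-5) / (w_levels / 2)

# scale while transposed so W stays in [out, in] layout, as in the NOTE above
w_int8 = (
    torch.round(w_smoothed.t() / w_scale)
    .clamp(-w_levels / 2, w_levels / 2)
    .t()
)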
@@ -1323,8 +1326,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             self.weight.shape[0],
         )  # W.shape=[out,in]
 
-        if torch.all(self.smq_scale != 1).item():
-            x = x.view(re_shape) / self.smq_scale
+        if torch.all(self.smoothq_scale != 1).item():
+            x = x.view(re_shape) / self.smoothq_scale
         else:
             x = x.view(re_shape)
 
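
This hunk is the forward-side half of item 3 in the NOTE: since the factor was folded into W, it must be divided out of the activations before the matmul. A minimal sketch of how the registered `smoothq_scale` buffer is applied; `re_shape` and the tensor shapes are assumptions:

import torch

x = torch.randn(2, 7, 128)                  # [batch, seq, in_features]
smoothq_scale = torch.rand(1, 128) + 0.5    # buffer registered in from_fms_mo
re_shape = (-1, x.shape[-1])                # flatten to 2D for the matmul

if torch.all(smoothq_scale != 1).item():
    x = x.view(re_shape) / smoothq_scale    # undo the factor folded into W
else:
    x = x.view(re_shape)                    # scale is all ones, smoothing disabled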