File tree Expand file tree Collapse file tree 1 file changed +9
-5
lines changed
Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -223,7 +223,10 @@ def __init__(
223223 self .smoothq = qcfg .get ("smoothq" , False )
224224 if self .smoothq :
225225 self .register_buffer ("smoothq_act_scale" , torch .zeros (w_size [1 ]))
226- self .smoothq_alpha = qcfg .get ("smoothq_alpha" , 0.5 )
226+ self .register_buffer (
227+ "smoothq_alpha" ,
228+ torch .tensor ([qcfg .get ("smoothq_alpha" , 0.5 )], dtype = torch .float32 ),
229+ )
227230
228231 def forward (self , x ):
229232 """
@@ -335,11 +338,12 @@ def get_smoothq_scale(self, x):
335338 smoothq_scale = torch .tensor ([1.0 ]).to (x .dtype ).to (x .device )
336339 else :
337340 weight_scale = self .weight .abs ().max (dim = 0 , keepdim = True )[0 ].clamp (min = 1e-5 )
341+ if isinstance (self .smoothq_alpha , torch .Tensor ):
342+ alpha = self .smoothq_alpha .item ()
343+ else :
344+ alpha = self .smoothq_alpha
338345 smoothq_scale = (
339- (
340- self .smoothq_act_scale .pow (self .smoothq_alpha )
341- / weight_scale .pow (1.0 - self .smoothq_alpha )
342- )
346+ (self .smoothq_act_scale .pow (alpha ) / weight_scale .pow (1.0 - alpha ))
343347 .clamp (min = 1e-5 )
344348 .to (x .dtype )
345349 )
You can’t perform that action at this time.
0 commit comments