@@ -259,23 +259,20 @@ def __init__(
259259 dtype = self .float_type ,
260260 ).to (device ),
261261 )
262- self .scales = self .scales .T
263262 self .register_buffer (
264263 "qweight" ,
265264 torch .zeros (
266265 (math .ceil (in_features / self .n_pack ), out_features ),
267266 dtype = self .compression_dtype ,
268267 ).to (device ),
269268 )
270- self .qweight = self .qweight .T
271269 self .register_buffer (
272270 "qzeros" ,
273271 torch .zeros (
274272 (math .ceil (self .in_features / self .groupsize ), math .ceil (self .out_features / self .n_pack )),
275273 dtype = self .compression_dtype ,
276274 ).to (device ),
277275 )
278- self .qzeros = self .qzeros .T
279276 self .register_buffer ("bias" , torch .zeros (self .out_features , dtype = self .float_type ).to (device ))
280277 else :
281278 self .compression_dtype = compression_dtype
@@ -329,6 +326,10 @@ def __init__(
329326 self .g_idx = None
330327
331328 def pack (self , int_weight , scale , zp , bias , g_idx = None ):
329+ if self .use_optimum_format :
330+ self .scales = self .scales .T
331+ self .qweight = self .qweight .T
332+ self .qzeros = self .qzeros .T
332333 int_weight = int_weight .to (self .device )
333334 if self .use_optimum_format and zp is None :
334335 # to avoid overflow
@@ -468,12 +469,13 @@ def recover(self):
468469 return fp32_weight
469470
470471 def forward (self , input ):
471- weight = self .recover ()
472- device = self .scales .device
473- if weight .dtype == torch .float16 and device .type == "cpu" :
474- weight = weight .float ()
475- self .bias = self .bias .float () if self .bias is not None else None
476- if level == DEBUG :
472+ if not hasattr (self , "weight" ):
473+ weight = self .recover ()
474+ device = self .scales .device
475+ if weight .dtype == torch .float16 and device .type == "cpu" :
476+ weight = weight .float ()
477+ self .bias = self .bias .float () if self .bias is not None else None
477+ if True : # always cache and reuse self.weight, because recover() is too slow to call on every forward pass.
477479 if not hasattr (self , "weight" ):
478480 self .weight = weight
479481 input = input .type (self .weight .dtype )
0 commit comments