@@ -306,6 +306,15 @@ def to(self, *args, **kwargs):
306306 return new_param
307307
308308
309+ def maybe_rearrange_weight (state_dict , prefix , local_metadata , strict , missing_keys , unexpected_keys , error_msgs ):
310+ weight = state_dict [f"{ prefix } weight" ]
311+ weight_format = state_dict .pop (f"{ prefix } weight_format" , "row" )
312+
313+ if weight_format != "row" :
314+ tile_indices = get_tile_inds (weight_format , weight .device )
315+ state_dict [f"{ prefix } weight" ] = undo_layout (weight , tile_indices )
316+
317+
309318class Linear8bitLt (nn .Linear ):
310319 def __init__ (self , input_features , output_features , bias = True , has_fp16_weights = True ,
311320 memory_efficient_backward = False , threshold = 0.0 , index = None ):
@@ -321,6 +330,7 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=
321330 self .state .use_pool = True
322331
323332 self .weight = Int8Params (self .weight .data , has_fp16_weights = has_fp16_weights , requires_grad = has_fp16_weights )
333+ self ._register_load_state_dict_pre_hook (maybe_rearrange_weight )
324334
325335 def _save_to_state_dict (self , destination , prefix , keep_vars ):
326336 super ()._save_to_state_dict (destination , prefix , keep_vars )
@@ -370,12 +380,6 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
370380 self .state .SCB = self .weight .SCB
371381
372382 unexpected_keys .remove (key )
373- if input_name == "weight_format" :
374- input_param = state_dict [key ]
375- if input_param != "row" :
376- tile_indices = get_tile_inds (input_param , self .weight .device )
377- self .weight .data = self .weight .CB = undo_layout (self .weight .data , tile_indices )
378- unexpected_keys .remove (key )
379383
380384 def init_8bit_state (self ):
381385 self .state .CB = self .weight .CB
0 commit comments