@@ -10,7 +10,7 @@
 
 import bitsandbytes as bnb
 import bitsandbytes.functional
-from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
+from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import OutlierTracer, find_outlier_dims
 
@@ -306,6 +306,17 @@ def to(self, *args, **kwargs):
         return new_param
 
 
+def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    weight = state_dict.get(f"{prefix}weight")
+    if weight is None:
+        # if the state dict has no weights for this layer (e.g., LoRA finetuning), do nothing
+        return
+    weight_format = state_dict.pop(f"{prefix}weight_format", "row")
+
+    if weight_format != "row":
+        tile_indices = get_tile_inds(weight_format, weight.device)
+        state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
+
 
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
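For context: maybe_rearrange_weight has the signature PyTorch expects from a load-state-dict pre-hook, and the next hunk registers it in Linear8bitLt.__init__. The sketch below is not part of the diff; it shows the same hook mechanism on a toy module using only plain PyTorch, with a hypothetical "scale" entry standing in for the real "weight_format" key.

import torch
import torch.nn as nn

def rescale_weight(state_dict, prefix, local_metadata, strict,
                   missing_keys, unexpected_keys, error_msgs):
    # runs before load_state_dict copies tensors into the module
    weight = state_dict.get(f"{prefix}weight")
    if weight is None:
        return
    # pop the auxiliary key so the default loader does not report it as unexpected
    scale = state_dict.pop(f"{prefix}scale", 1.0)
    state_dict[f"{prefix}weight"] = weight * scale

class ToyLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._register_load_state_dict_pre_hook(rescale_weight)

m = ToyLinear(4, 4)
sd = m.state_dict()
sd["scale"] = torch.tensor(2.0)
m.load_state_dict(sd)  # hook rescales the weight and consumes the extra key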
@@ -322,52 +333,55 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=
             self.state.use_pool = True
 
         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
+        self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
-        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
-            # reorder weight layout back from ampere/turing to row
-            reorder_layout = True
-            weight_clone = self.weight.data.clone()
-        else:
-            reorder_layout = False
+        super()._save_to_state_dict(destination, prefix, keep_vars)
 
-        try:
-            if reorder_layout:
-                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
+        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+        scb_name = "SCB"
 
-            super()._save_to_state_dict(destination, prefix, keep_vars)
+        # case 1: .cuda was called, SCB is in self.weight
+        param_from_weight = getattr(self.weight, scb_name)
+        # case 2: self.init_8bit_state was called, SCB is in self.state
+        param_from_state = getattr(self.state, scb_name)
+        # case 3: SCB is in self.state, weight layout reordered after first forward()
+        layout_reordered = self.state.CxB is not None
 
-            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
-            weight_name = "SCB"
+        key_name = prefix + f"{scb_name}"
+        format_name = prefix + "weight_format"
 
-            # case 1: .cuda was called, SCB is in self.weight
-            param_from_weight = getattr(self.weight, weight_name)
-            # case 2: self.init_8bit_state was called, SCB is in self.state
-            param_from_state = getattr(self.state, weight_name)
-
-            key_name = prefix + f"{weight_name}"
+        if not self.state.has_fp16_weights:
             if param_from_weight is not None:
                 destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-            elif not self.state.has_fp16_weights and param_from_state is not None:
+                destination[format_name] = "row"
+            elif param_from_state is not None and not layout_reordered:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-        finally:
-            if reorder_layout:
-                self.weight.data = weight_clone
+                destination[format_name] = "row"
+            elif param_from_state is not None:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+                destination[format_name] = self.state.formatB
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                       error_msgs)
-        for key in unexpected_keys:
+        unexpected_copy = list(unexpected_keys)
+
+        for key in unexpected_copy:
             input_name = key[len(prefix):]
             if input_name == "SCB":
                 if self.weight.SCB is None:
-                    # buffers not yet initialized, can't call them directly without
+                    # buffers not yet initialized, can't access them directly without quantizing first
                     raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
                                        "not supported. Please call module.cuda() before module.load_state_dict()")
 
                 input_param = state_dict[key]
                 self.weight.SCB.copy_(input_param)
+
+                if self.state.SCB is not None:
+                    self.state.SCB = self.weight.SCB
+
                 unexpected_keys.remove(key)
 
     def init_8bit_state(self):
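Taken together, the two paths round-trip a quantized layer: _save_to_state_dict writes the int8 weight plus "SCB" and "weight_format", the maybe_rearrange_weight pre-hook undoes any non-"row" layout on load, and _load_from_state_dict restores SCB. A usage sketch, not part of the diff (assumes a CUDA GPU with int8 support; the sizes, threshold value, and variable names are illustrative):

import torch
import bitsandbytes as bnb

# Build and quantize a layer: .cuda() triggers int8 quantization because has_fp16_weights=False.
fp16_linear = torch.nn.Linear(64, 64).half().cuda()
int8_linear = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0)
int8_linear.load_state_dict(fp16_linear.state_dict())
int8_linear = int8_linear.cuda()

x = torch.randn(1, 64, dtype=torch.float16, device="cuda")
out_original = int8_linear(x)  # the first forward may reorder CxB into a GPU-specific tile layout

# state_dict() now contains the int8 weight plus the "SCB" and "weight_format" entries.
state = int8_linear.state_dict()

# Quantize the target module first (module.cuda() before load_state_dict, as the
# RuntimeError above requires), then load; a non-"row" layout is undone by the pre-hook.
restored = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0).cuda()
restored.load_state_dict(state)
out_restored = restored(x)  # should match out_original, since CB and SCB are identical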