@@ -1983,7 +1983,8 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 out_features = state_dict[lora_B_weight_name].shape[0]

                 # Model maybe loaded with different quantization schemes which may flatten the params.
-                # `bitsandbytes`, for example, flatten the weights when using 4bit.
+                # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
+                # preserve weight shape.
                 module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)

                 # This means there's no need for an expansion in the params, so we simply skip.
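
The context above leans on `cls._calculate_module_shape` to recover the logical shape of a weight even when the stored parameter has been flattened by quantization. A minimal sketch of that idea, assuming `bitsandbytes` is the quantization backend (the `Params4bit` / `quant_state` names come from `bitsandbytes`, not from this diff):

import torch

def get_logical_weight_shape(weight: torch.Tensor) -> torch.Size:
    # bitsandbytes 4-bit parameters are stored as flattened packed bytes;
    # the original shape is kept on the quantization state.
    if weight.__class__.__name__ == "Params4bit":
        return weight.quant_state.shape
    # 8-bit bnb parameters and unquantized weights keep their original shape.
    return weight.shape
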
@@ -2120,11 +2121,11 @@ def _get_weight_shape(weight: torch.Tensor):
         if base_module is not None:
             return _get_weight_shape(base_module.weight)
         elif base_weight_param_name is not None:
-            module_path = (
-                base_weight_param_name.rsplit(".weight", 1)[0]
-                if base_weight_param_name.endswith(".weight")
-                else base_weight_param_name
-            )
+            if not base_weight_param_name.endswith(".weight"):
+                raise ValueError(
+                    f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+                )
+            module_path = base_weight_param_name.rsplit(".weight", 1)[0]
             submodule = get_submodule_by_name(model, module_path)
             return _get_weight_shape(submodule.weight)

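
A small usage sketch of the changed branch: a parameter name ending in ".weight" is stripped to its module path, while any other name now raises instead of being passed through as-is (the example name is hypothetical, not taken from this diff):

name = "transformer_blocks.0.attn.to_q.weight"
if not name.endswith(".weight"):
    raise ValueError(f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {name=}.")
module_path = name.rsplit(".weight", 1)[0]
# module_path == "transformer_blocks.0.attn.to_q"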