 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    get_submodule_by_name,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -1981,16 +1982,12 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 in_features = state_dict[lora_A_weight_name].shape[1]
                 out_features = state_dict[lora_B_weight_name].shape[0]
 
+                # The model may be loaded with different quantization schemes, which can flatten the params.
+                # `bitsandbytes`, for example, flattens the weights when using 4-bit.
+                module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
                 # This means there's no need for an expansion in the params, so we simply skip.
-                module_weight_shape = module_weight.shape
-                expansion_shape = (out_features, in_features)
-                quantization_config = getattr(transformer, "quantization_config", None)
-                if quantization_config and quantization_config.quant_method == "bitsandbytes":
-                    if quantization_config.load_in_4bit:
-                        expansion_shape = torch.Size(expansion_shape).numel()
-                        expansion_shape = ((expansion_shape + 1) // 2, 1)
-
-                if tuple(module_weight_shape) == expansion_shape:
+                if tuple(module_weight_shape) == (out_features, in_features):
                     continue
 
                 # TODO (sayakpaul): We still need to consider if the module we're expanding is
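
For context, the removed branch reproduced by hand what bitsandbytes does to 4-bit weights: the packed `Params4bit` storage collapses to `((numel + 1) // 2, 1)`, while the logical `(out_features, in_features)` shape is kept on the weight's `quant_state`. A minimal sketch of that behaviour (assumes `bitsandbytes` is installed and a CUDA device is available; the layer sizes are arbitrary):

```python
import torch
import bitsandbytes as bnb

# A 4-bit linear layer; quantization happens when the module is moved to the GPU.
linear = bnb.nn.Linear4bit(3072, 64, bias=False).to("cuda")

print(linear.weight.__class__.__name__)        # "Params4bit"
print(tuple(linear.weight.shape))              # packed storage, e.g. (98304, 1) == ((64 * 3072 + 1) // 2, 1)
print(tuple(linear.weight.quant_state.shape))  # logical shape: (64, 3072)
```
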
@@ -2090,22 +2087,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
 
-            # TODO (sayakpaul): Handle the cases when we actually need to expand.
-            base_out_feature_shape = base_weight_param.shape[1]
-            lora_A_out_feature_shape = lora_A_param.shape[1]
-            quantization_config = getattr(transformer, "quantization_config", None)
-            if quantization_config and quantization_config.quant_method == "bitsandbytes":
-                if quantization_config.load_in_4bit:
-                    lora_A_out_feature_shape = lora_A_param.shape.numel()
-                    lora_A_out_feature_shape = ((lora_A_out_feature_shape + 1) // 2, 1)[1]
+            # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+            base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
 
-            if base_out_feature_shape > lora_A_out_feature_shape:
+            if base_module_shape[1] > lora_A_param.shape[1]:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif lora_A_out_feature_shape < lora_A_out_feature_shape:
+            elif base_module_shape[1] < lora_A_param.shape[1]:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file this."
                 )
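
The expansion branch above zero-pads the LoRA's A matrix along its input dimension so it lines up with a base layer whose `in_features` was expanded; the padded columns are zeros, so they contribute nothing to the LoRA update. A standalone sketch of the same padding with made-up sizes:

```python
import torch

base_in_features = 3072        # in_features of the expanded base layer
lora_A = torch.randn(16, 768)  # rank-16 LoRA A matrix trained against a narrower layer

if base_in_features > lora_A.shape[1]:
    expanded = torch.zeros(lora_A.shape[0], base_in_features)
    expanded[:, : lora_A.shape[1]].copy_(lora_A)  # original columns; the rest stays zero
    lora_A = expanded

print(lora_A.shape)  # torch.Size([16, 3072])
```
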
@@ -2117,6 +2108,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
 
         return lora_state_dict
 
+    @staticmethod
+    def _calculate_module_shape(
+        model: "torch.nn.Module",
+        base_module: "torch.nn.Linear" = None,
+        base_weight_param_name: str = None,
+    ) -> "torch.Size":
+        def _get_weight_shape(weight: torch.Tensor):
+            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+        if base_module is not None:
+            return _get_weight_shape(base_module.weight)
+        elif base_weight_param_name is not None:
+            module_path = (
+                base_weight_param_name.rsplit(".weight", 1)[0]
+                if base_weight_param_name.endswith(".weight")
+                else base_weight_param_name
+            )
+            submodule = get_submodule_by_name(model, module_path)
+            return _get_weight_shape(submodule.weight)
+
+        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
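
A quick usage sketch for the new helper with a plain, unquantized module (the enclosing class is assumed to be `FluxLoraLoaderMixin`; both call paths then fall back to `weight.shape`, while a `Params4bit` weight would resolve through `weight.quant_state.shape` instead):

```python
import torch
from diffusers.loaders import FluxLoraLoaderMixin  # assumed host class for the helper


class TinyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x_embedder = torch.nn.Linear(64, 128, bias=False)


model = TinyTransformer()

# Path 1: pass the module directly.
print(FluxLoraLoaderMixin._calculate_module_shape(model=model, base_module=model.x_embedder))

# Path 2: pass the weight param name; ".weight" is stripped and the submodule is
# resolved via `get_submodule_by_name`.
print(FluxLoraLoaderMixin._calculate_module_shape(model=model, base_weight_param_name="x_embedder.weight"))
# Both print torch.Size([128, 64]).
```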