 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    get_submodule_by_name,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -1981,10 +1982,17 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 in_features = state_dict[lora_A_weight_name].shape[1]
                 out_features = state_dict[lora_B_weight_name].shape[0]

+                # The model may be loaded with different quantization schemes, which can flatten the params.
+                # `bitsandbytes`, for example, flattens the weights when using 4-bit; 8-bit bnb models
+                # preserve the weight shape.
+                module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
                 # This means there's no need for an expansion in the params, so we simply skip.
-                if tuple(module_weight.shape) == (out_features, in_features):
+                if tuple(module_weight_shape) == (out_features, in_features):
                     continue

+                # TODO (sayakpaul): We still need to consider if the module we're expanding is
+                # quantized and handle it accordingly if that is the case.
                 module_out_features, module_in_features = module_weight.shape
                 debug_message = ""
                 if in_features > module_in_features:
@@ -2080,13 +2088,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

-            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+            # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+            base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
+            if base_module_shape[1] > lora_A_param.shape[1]:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+            elif base_module_shape[1] < lora_A_param.shape[1]:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}."
                 )
@@ -2098,6 +2109,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):

         return lora_state_dict

+    @staticmethod
+    def _calculate_module_shape(
+        model: "torch.nn.Module",
+        base_module: "torch.nn.Linear" = None,
+        base_weight_param_name: str = None,
+    ) -> "torch.Size":
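+        # bnb 4-bit weights (`Params4bit`) are stored as a flattened, packed tensor; the
+        # original (out_features, in_features) shape is recorded on `weight.quant_state`.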
+        def _get_weight_shape(weight: torch.Tensor):
+            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+        if base_module is not None:
+            return _get_weight_shape(base_module.weight)
+        elif base_weight_param_name is not None:
+            if not base_weight_param_name.endswith(".weight"):
+                raise ValueError(
+                    f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+                )
+            module_path = base_weight_param_name.rsplit(".weight", 1)[0]
+            submodule = get_submodule_by_name(model, module_path)
+            return _get_weight_shape(submodule.weight)
+
+        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+

 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
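
For reference, here is a minimal, self-contained sketch of the behaviour `_calculate_module_shape` has to account for: bnb 4-bit parameters expose a flattened storage shape, while the logical shape lives on `quant_state`. The `Params4bit` class below is only a stand-in for the real `bitsandbytes` parameter class (it mimics just the attributes the helper reads), so the snippet runs without `bitsandbytes` or a GPU.

import types

import torch


class Params4bit(torch.nn.Parameter):
    # Stand-in for `bitsandbytes.nn.Params4bit`: holds the packed storage tensor and
    # records the logical shape on `quant_state`, which is what the helper reads.
    def __new__(cls, packed, logical_shape):
        param = super().__new__(cls, packed, requires_grad=False)
        param.quant_state = types.SimpleNamespace(shape=logical_shape)
        return param


def _get_weight_shape(weight: torch.Tensor):
    # Same duck-typed check as in `_calculate_module_shape` above.
    return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape


packed = torch.zeros(3072 * 64 // 2, 1, dtype=torch.uint8)  # packed 4-bit storage
weight = Params4bit(packed, torch.Size([3072, 64]))

print(weight.shape)               # torch.Size([98304, 1]) -> flattened storage shape
print(_get_weight_shape(weight))  # torch.Size([3072, 64]) -> logical (out_features, in_features)

This is why the patch compares `tuple(module_weight_shape)` rather than `module_weight.shape` against `(out_features, in_features)` when deciding whether an expansion is needed.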