 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    get_submodule_by_name,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -1981,10 +1982,17 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 in_features = state_dict[lora_A_weight_name].shape[1]
                 out_features = state_dict[lora_B_weight_name].shape[0]

+                # The model may be loaded with different quantization schemes, which may flatten the params.
+                # `bitsandbytes`, for example, flattens the weights when using 4-bit; 8-bit bnb models
+                # preserve the weight shape.
+                module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
                 # This means there's no need for an expansion in the params, so we simply skip.
-                if tuple(module_weight.shape) == (out_features, in_features):
+                if tuple(module_weight_shape) == (out_features, in_features):
                     continue

+                # TODO (sayakpaul): We still need to consider if the module we're expanding is
+                # quantized and handle it accordingly if that is the case.
                 module_out_features, module_in_features = module_weight.shape
                 debug_message = ""
                 if in_features > module_in_features:
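The shape check above now goes through `cls._calculate_module_shape(...)` because bitsandbytes 4-bit quantization packs a Linear weight into a flat buffer, so `module_weight.shape` no longer reflects the logical `(out_features, in_features)` layout. A minimal sketch of that mismatch, assuming `bitsandbytes` is installed and a CUDA device is available (the layer size is made up for illustration and is not part of this diff):

import torch
import bitsandbytes as bnb

# Quantization happens when the module is moved to the GPU.
layer = bnb.nn.Linear4bit(3072, 3072, bias=False).to("cuda")

print(layer.weight.shape)              # packed storage, e.g. torch.Size([4718592, 1])
print(layer.weight.quant_state.shape)  # logical shape, torch.Size([3072, 3072])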
@@ -2080,13 +2088,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

-            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+            # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+            base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
+            if base_module_shape[1] > lora_A_param.shape[1]:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+            elif base_module_shape[1] < lora_A_param.shape[1]:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                 )
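When the base layer is wider than the checkpoint's LoRA A matrix, the branch above zero-pads the extra input columns so the LoRA path simply ignores the new features. A standalone sketch of the same padding pattern, with made-up sizes:

import torch

rank, lora_in, base_in = 4, 64, 128
lora_A = torch.randn(rank, lora_in)

expanded = torch.zeros(rank, base_in)
expanded[:, :lora_in].copy_(lora_A)

x = torch.randn(1, base_in)
# The padded columns are zero, so only the first `lora_in` input features contribute.
assert torch.allclose(x @ expanded.T, x[:, :lora_in] @ lora_A.T)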
@@ -2098,6 +2109,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):

         return lora_state_dict

+    @staticmethod
+    def _calculate_module_shape(
+        model: "torch.nn.Module",
+        base_module: "torch.nn.Linear" = None,
+        base_weight_param_name: str = None,
+    ) -> "torch.Size":
+        def _get_weight_shape(weight: torch.Tensor):
+            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+        if base_module is not None:
+            return _get_weight_shape(base_module.weight)
+        elif base_weight_param_name is not None:
+            if not base_weight_param_name.endswith(".weight"):
+                raise ValueError(
+                    f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+                )
+            module_path = base_weight_param_name.rsplit(".weight", 1)[0]
+            submodule = get_submodule_by_name(model, module_path)
+            return _get_weight_shape(submodule.weight)
+
+        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+

 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
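For reference, a rough standalone re-implementation of the new helper's two call paths; this is an illustration only, substituting `torch.nn.Module.get_submodule` for diffusers' `get_submodule_by_name` and a toy model for the transformer:

import torch

def calculate_module_shape_sketch(model, base_module=None, base_weight_param_name=None):
    def _get_weight_shape(weight):
        # bnb 4-bit (`Params4bit`) keeps the logical 2-D shape on its quant state;
        # regular tensors (and bnb 8-bit weights) just report `.shape`.
        return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape

    if base_module is not None:
        return _get_weight_shape(base_module.weight)
    if base_weight_param_name is not None:
        module_path = base_weight_param_name.rsplit(".weight", 1)[0]
        return _get_weight_shape(model.get_submodule(module_path).weight)
    raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")

model = torch.nn.Sequential()
model.add_module("proj_out", torch.nn.Linear(64, 32))

print(calculate_module_shape_sketch(model, base_module=model.proj_out))                # torch.Size([32, 64])
print(calculate_module_shape_sketch(model, base_weight_param_name="proj_out.weight"))  # torch.Size([32, 64])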