Commit 42970ee

improve log messages
1 parent 6ef2c8b commit 42970ee

File tree

1 file changed: +39 -6 lines changed


src/diffusers/loaders/lora_pipeline.py

Lines changed: 39 additions & 6 deletions
@@ -1822,10 +1822,17 @@ def prune_state_dict_(state_dict):
         }
 
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-        self._maybe_expand_transformer_param_shape_(
+        has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
             transformer, transformer_lora_state_dict, transformer_norm_state_dict
         )
 
+        if has_param_with_expanded_shape:
+            logger.info(
+                "The LoRA weights contain parameters that have different shapes than expected by the transformer. "
+                "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
+                "To get a comprehensive list of parameter names that were modified, enable debug logging."
+            )
+
         if len(transformer_lora_state_dict) > 0:
             self.load_lora_into_transformer(
                 transformer_lora_state_dict,
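
For reference, the new info message points users at debug logging to see which parameters were expanded. A minimal sketch of what that looks like from user code, assuming a Flux pipeline and a Control-LoRA-style checkpoint (the LoRA file path below is a placeholder):

```python
import torch
from diffusers import FluxPipeline
from diffusers.utils import logging

# Surface the per-module debug messages emitted while the LoRA is loaded.
logging.set_verbosity_debug()

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# Placeholder file; any LoRA whose linear layers are wider than the base
# transformer's will exercise the shape-expansion path guarded above.
pipe.load_lora_weights("path/to/control_lora.safetensors")
```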
@@ -1931,10 +1938,13 @@ def _load_norm_into_transformer(
         for key in state_dict.keys():
             overwritten_layers_state_dict[key] = transformer_state_dict[key]
 
-        # We can't load with strict=True because the current state_dict does not contain all the transformer keys
         logger.info(
-            "Normalization layers in LoRA state dict can only be loaded if fused directly in the transformer. Calls to `.fuse_lora()` will only affect the LoRA layers and not the normalization layers."
+            "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
+            'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
+            "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
         )
+
+        # We can't load with strict=True because the current state_dict does not contain all the transformer keys
         transformer.load_state_dict(state_dict, strict=False)
 
         return overwritten_layers_state_dict
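
The reworded message distinguishes two mechanisms: normalization weights are written straight into the transformer's parameters at load time, while LoRA weights live in separate adapter layers until `fuse_lora()` is called. A hedged way to observe the first half of that claim, assuming a `pipe` set up as in the sketch above but before `load_lora_weights` has run, and a checkpoint that actually carries normalization keys:

```python
import torch

# Snapshot the normalization parameters before loading the LoRA.
norm_before = {
    k: v.detach().clone()
    for k, v in pipe.transformer.state_dict().items()
    if "norm" in k
}

pipe.load_lora_weights("path/to/control_lora.safetensors")  # placeholder path

# Norm layers from the checkpoint overwrite the transformer in place...
norm_after = pipe.transformer.state_dict()
overwritten = [k for k in norm_before if not torch.equal(norm_before[k], norm_after[k])]
print(f"{len(overwritten)} normalization parameters were overwritten in place")

# ...whereas the LoRA layers are only folded into the base weights here.
pipe.fuse_lora()
```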
@@ -2175,7 +2185,9 @@ def fuse_lora(
         """
         if len(self._transformer_norm_layers.keys()) > 0:
             logger.info(
-                "Normalization layers cannot be loaded without fusing. Calls to `.fuse_lora()` will only affect the actual LoRA layers."
+                "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
+                'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
+                "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
             )
 
         super().fuse_lora(
@@ -2202,13 +2214,13 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
         super().unfuse_lora(components=components)
 
     @classmethod
-    def _maybe_expand_transformer_param_shape_(
+    def _maybe_expand_transformer_param_shape_or_error_(
         cls,
         transformer: torch.nn.Module,
         lora_state_dict=None,
         norm_state_dict=None,
         prefix=None,
-    ):
+    ) -> bool:
         state_dict = {}
         if lora_state_dict is not None:
             state_dict.update(lora_state_dict)
@@ -2231,6 +2243,8 @@ def get_submodule(module, name):
             return module
 
         # Expand transformer parameter shapes if they don't match lora
+        has_param_with_shape_update = False
+
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
@@ -2252,6 +2266,23 @@ def get_submodule(module, name):
                 if tuple(module_weight.shape) == (out_features, in_features):
                     continue
 
+                module_out_features, module_in_features = module_weight.shape
+                if out_features < module_out_features or in_features < module_in_features:
+                    raise NotImplementedError(
+                        f"Only LoRAs with input/output features higher than the current modules' input/output features "
+                        f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
+                        f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
+                        f"this please open an issue at https://github.com/huggingface/diffusers/issues."
+                    )
+
+                logger.debug(
+                    f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                    f"checkpoint contains higher number of features than expected. The number of input_features will be "
+                    f"expanded from {module_in_features} to {in_features}, and the number of output features will be "
+                    f"expanded from {module_out_features} to {out_features}."
+                )
+
+                has_param_with_shape_update = True
                 parent_module_name = ".".join(name_split[:-1])
                 current_module_name = name_split[-1]
                 parent_module = get_submodule(transformer, parent_module_name)
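
The `expanded_module` that ultimately replaces an undersized layer is built further down in this method, outside the hunks shown here. As a rough, illustrative sketch of the technique the debug message describes, assuming the original weights are kept in the top-left block and the new rows and columns are zero-initialized (the helper name is made up for this example):

```python
import torch


def expand_linear(module: torch.nn.Linear, in_features: int, out_features: int) -> torch.nn.Linear:
    # Illustrative only: build a larger Linear whose top-left weight block is
    # the original module's weight and whose extra entries are zeros.
    bias = module.bias is not None
    expanded = torch.nn.Linear(
        in_features,
        out_features,
        bias=bias,
        device=module.weight.device,
        dtype=module.weight.dtype,
    )
    new_weight = torch.zeros_like(expanded.weight.data)
    slices = tuple(slice(0, dim) for dim in module.weight.shape)
    new_weight[slices] = module.weight.data
    expanded.weight.data.copy_(new_weight)
    if bias:
        expanded.bias.data.zero_()
        expanded.bias.data[: module.bias.shape[0]] = module.bias.data
    return expanded
```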
@@ -2277,6 +2308,8 @@ def get_submodule(module, name):
 
                 setattr(parent_module, current_module_name, expanded_module)
 
+        return has_param_with_shape_update
+
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
