
Commit ed91c53

updates
1 parent 5001efe commit ed91c53

File tree

1 file changed: +15 -10 lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 15 additions & 10 deletions
@@ -1863,7 +1863,9 @@ def load_lora_weights(
                 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
                 "To get a comprehensive list of parameter names that were modified, enable debug logging."
             )
-        transformer_lora_state_dict = self._maybe_expand_lora_state_dict(transformer=transformer, lora_state_dict=transformer_lora_state_dict)
+        transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+            transformer=transformer, lora_state_dict=transformer_lora_state_dict
+        )
 
         if len(transformer_lora_state_dict) > 0:
             self.load_lora_into_transformer(
@@ -2385,29 +2387,32 @@ def _maybe_expand_transformer_param_shape_or_error_(
     def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
         expanded_module_names = set()
         transformer_state_dict = transformer.state_dict()
-        lora_module_names = set([k.replace(".lora_A.weight", "") for k in lora_state_dict if "lora_A" in k])
-        lora_module_names = sorted(lora_module_names)
+        lora_module_names = sorted({k.replace(".lora_A.weight", "") for k in lora_state_dict if "lora_A" in k})
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
 
         for k in lora_module_names:
-            base_param_name = f"{k.replace(f'{cls.transformer_name}.', '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(f'{cls.transformer_name}.', '')}.weight"
+            base_param_name = (
+                f"{k.replace(f'{cls.transformer_name}.', '')}.base_layer.weight"
+                if is_peft_loaded
+                else f"{k.replace(f'{cls.transformer_name}.', '')}.weight"
+            )
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{k}.lora_A.weight"]
-            # lora_B_param = lora_state_dict[f"{k}.lora_B.weight"]
 
             if base_weight_param.shape[1] > lora_A_param.shape[1]:
-                shape = (lora_A_param.shape[0], base_weight_param.shape[1])
-                expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
-                expanded_state_dict_weight[:, :lora_A_param.shape[1]].copy_(lora_A_param)
+                # could be made more advanced with `repeats`.
+                # have tried zero-padding but that doesn't work, either.
+                expanded_state_dict_weight = torch.cat([lora_A_param, lora_A_param], dim=1)
                 lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
 
         if expanded_module_names:
-            logger.info(f"Found some LoRA modules for which the weights were expanded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new.")
+            logger.info(
+                f"Found some LoRA modules for which the weights were expanded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )
         return lora_state_dict
 
 
-
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
 class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
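For reference, a minimal standalone sketch of the new expansion step follows, using hypothetical shapes (the rank and feature sizes below are illustrative, not taken from this commit): when the base layer expects more input features than the LoRA A matrix provides, the commit duplicates the A matrix along dim=1 with torch.cat instead of zero-padding it, which only lines up when the base input width is exactly twice the LoRA's.

# Standalone sketch of the expansion step (hypothetical shapes, for illustration only).
import torch

rank, lora_in_features, base_in_features = 16, 64, 128  # assumed sizes, not from the commit

base_weight_param = torch.randn(256, base_in_features)  # stands in for <module>.base_layer.weight
lora_A_param = torch.randn(rank, lora_in_features)      # stands in for <module>.lora_A.weight

if base_weight_param.shape[1] > lora_A_param.shape[1]:
    # Duplicate lora_A along dim=1 rather than zero-padding; this matches the
    # base width only when base_in_features == 2 * lora_in_features.
    lora_A_param = torch.cat([lora_A_param, lora_A_param], dim=1)

assert lora_A_param.shape == (rank, base_in_features)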
