
Commit 3d735b4
1 parent: d041dd5

lora expansion with dummy zeros.
1 file changed (+27, −0)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 27 additions & 0 deletions
@@ -1863,6 +1863,7 @@ def load_lora_weights(
                 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
                 "To get a comprehensive list of parameter names that were modified, enable debug logging."
             )
+        transformer_lora_state_dict = self._maybe_expand_lora_state_dict(transformer=transformer, lora_state_dict=transformer_lora_state_dict)
 
         if len(transformer_lora_state_dict) > 0:
             self.load_lora_into_transformer(
@@ -2373,6 +2374,32 @@ def _maybe_expand_transformer_param_shape_or_error_(
 
         return has_param_with_shape_update
 
+    @classmethod
+    def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+        expanded_module_names = set()
+        transformer_state_dict = transformer.state_dict()
+        lora_module_names = set([k.replace(".lora_A.weight", "") for k in lora_state_dict if "lora_A" in k])
+        lora_module_names = sorted(lora_module_names)
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+
+        for k in lora_module_names:
+            base_param_name = f"{k.replace(f'{cls.transformer_name}.', '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(f'{cls.transformer_name}.', '')}.weight"
+            base_weight_param = transformer_state_dict[base_param_name]
+            lora_A_param = lora_state_dict[f"{k}.lora_A.weight"]
+            # lora_B_param = lora_state_dict[f"{k}.lora_B.weight"]
+
+            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+                shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+                expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+                expanded_state_dict_weight[:, :lora_A_param.shape[1]].copy_(lora_A_param)
+                lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight
+                expanded_module_names.add(k)
+
+        if expanded_module_names:
+            logger.info(f"Found some LoRA modules for which the weights were expanded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new.")
+        return lora_state_dict
+
+
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
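
For illustration, here is a minimal, self-contained sketch of the zero-padding the new `_maybe_expand_lora_state_dict` helper applies to a single LoRA down-projection (`lora_A`) matrix. The tensor names and shapes below are invented for the example; only the padding logic mirrors the diff.

import torch

# Hypothetical shapes: the base linear layer was expanded to accept 128 input
# features, while the LoRA was trained against the original 64-feature input.
base_weight = torch.zeros(256, 128)  # stands in for transformer_state_dict[base_param_name]
lora_A = torch.randn(16, 64)         # rank-16 down-projection, trained on 64 inputs
lora_B = torch.randn(256, 16)        # up-projection (left untouched by the helper)

if base_weight.shape[1] > lora_A.shape[1]:
    # Pad lora_A with zero columns so it matches the expanded input width,
    # the same zeros-then-copy step the helper performs per module.
    expanded_A = torch.zeros(lora_A.shape[0], base_weight.shape[1], device=base_weight.device)
    expanded_A[:, : lora_A.shape[1]].copy_(lora_A)
    lora_A = expanded_A

# The LoRA delta (lora_B @ lora_A) now matches the expanded base shape, but the
# zero columns mean the extra 64 input features contribute nothing to the update.
delta = lora_B @ lora_A
assert delta.shape == base_weight.shape
assert torch.all(delta[:, 64:] == 0)

Because the padding is zeros rather than trained values (the "dummy zeros" of the commit message), the adapter behaves exactly as before on the original input slice and is simply inert on the newly added features.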
