
Commit d3e177c

fix working 🥳
1 parent: ed91c53

File tree: 1 file changed (+3, -3 lines)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 3 additions & 3 deletions
@@ -2400,9 +2400,9 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
             lora_A_param = lora_state_dict[f"{k}.lora_A.weight"]
 
             if base_weight_param.shape[1] > lora_A_param.shape[1]:
-                # could be made more advanced with `repeats`.
-                # have tried zero-padding but that doesn't work, either.
-                expanded_state_dict_weight = torch.cat([lora_A_param, lora_A_param], dim=1)
+                shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+                expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+                expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
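
For readers landing on this commit without the surrounding context: the old code expanded an undersized LoRA A matrix by concatenating it with itself along the input dimension, which applies the original weights to the extra input columns as well. The new code instead allocates a zero tensor with the base layer's input width and copies the original weights into the leading columns, so the adapter contributes nothing on the new inputs. Below is a minimal standalone sketch of that padding step; the variable names follow the diff, while the toy shapes are illustrative assumptions, not values from the repository.

import torch

# Toy shapes (assumed): a base Linear weight of (out_features=128, in_features=64)
# and a LoRA A matrix of (rank=4, in_features=32) trained against a narrower layer.
base_weight_param = torch.randn(128, 64)
lora_A_param = torch.randn(4, 32)

if base_weight_param.shape[1] > lora_A_param.shape[1]:
    # Zero-pad the LoRA A matrix up to the base layer's input width:
    # the original columns are copied in place, the trailing columns stay
    # zero, so the expanded adapter ignores the extra inputs.
    shape = (lora_A_param.shape[0], base_weight_param.shape[1])
    expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
    expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)

assert expanded_state_dict_weight.shape == (4, 64)
assert torch.equal(expanded_state_dict_weight[:, :32], lora_A_param)
assert not expanded_state_dict_weight[:, 32:].any()

Note the contrast with the removed comment: zero-padding had reportedly been tried and failed, and judging by the diff the fix is not a different idea but a correct in-place construction of the padded tensor (zeros first, then copy_), replacing the torch.cat([lora_A_param, lora_A_param], dim=1) workaround.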
