
Commit 0279a06

vace padding
committed · 1 parent b3394d4 · commit 0279a06

1 file changed: +43 −2 lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 43 additions & 2 deletions
@@ -4833,10 +4833,10 @@ def _maybe_expand_t2v_lora_for_i2v(
                         continue

                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
+                        state_dict[ref_key_lora_A], device=target_device  # Using original ref_key_lora_A
                     )
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
+                        state_dict[ref_key_lora_B], device=target_device  # Using original ref_key_lora_B
                     )

                     # If the original LoRA had biases (indicated by has_bias)
@@ -4850,6 +4850,47 @@ def _maybe_expand_t2v_lora_for_i2v(
                             device=target_device,
                         )

+        if hasattr(transformer, 'vace_blocks'):
+            inferred_rank_for_vace = None
+            lora_weights_dtype_for_vace = next(iter(transformer.parameters())).dtype  # Fallback dtype
+
+            for k_lora_any, v_lora_tensor_any in state_dict.items():
+                if k_lora_any.endswith(".lora_A.weight"):
+                    inferred_rank_for_vace = v_lora_tensor_any.shape[0]
+                    lora_weights_dtype_for_vace = v_lora_tensor_any.dtype
+                    break  # Found one, good enough for rank and dtype
+
+            if inferred_rank_for_vace is not None:
+                current_lora_has_bias = any(".lora_B.bias" in k for k in state_dict.keys())
+
+                for i, vace_block_module_in_model in enumerate(transformer.vace_blocks):
+                    if hasattr(vace_block_module_in_model, 'proj_out'):
+
+                        proj_out_linear_layer_in_model = vace_block_module_in_model.proj_out
+
+                        vace_lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
+                        vace_lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight"
+
+                        if vace_lora_A_key not in state_dict:
+                            state_dict[vace_lora_A_key] = torch.zeros(
+                                (inferred_rank_for_vace, proj_out_linear_layer_in_model.in_features),
+                                device=target_device, dtype=lora_weights_dtype_for_vace
+                            )
+
+                        if vace_lora_B_key not in state_dict:
+                            state_dict[vace_lora_B_key] = torch.zeros(
+                                (proj_out_linear_layer_in_model.out_features, inferred_rank_for_vace),
+                                device=target_device, dtype=lora_weights_dtype_for_vace
+                            )
+
+                        if current_lora_has_bias and proj_out_linear_layer_in_model.bias is not None:
+                            vace_lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias"
+                            if vace_lora_B_bias_key not in state_dict:
+                                state_dict[vace_lora_B_bias_key] = torch.zeros_like(
+                                    proj_out_linear_layer_in_model.bias,
+                                    device=target_device
+                                )
+
         return state_dict

     def load_lora_weights(
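
For context, below is a minimal standalone sketch of the padding pattern this commit adds, applied to a toy module rather than the real Wan VACE transformer. The ToyVaceBlock and ToyTransformer classes, the dimensions, and the toy state-dict keys are assumptions for illustration only; the point is that zero-initialized lora_A/lora_B pairs leave a layer's output unchanged, so a T2V-style LoRA can be expanded to cover every vace_blocks.{i}.proj_out without altering VACE behavior.

import torch
import torch.nn as nn


class ToyVaceBlock(nn.Module):
    # Stand-in for a VACE block: only the proj_out linear layer matters here.
    def __init__(self, dim=16):
        super().__init__()
        self.proj_out = nn.Linear(dim, dim)


class ToyTransformer(nn.Module):
    def __init__(self, num_vace_blocks=2, dim=16):
        super().__init__()
        self.vace_blocks = nn.ModuleList(ToyVaceBlock(dim) for _ in range(num_vace_blocks))


transformer = ToyTransformer()
target_device = "cpu"
rank = 4

# A T2V-style LoRA state dict with no entries for the VACE blocks.
state_dict = {
    "transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(rank, 16),
    "transformer.blocks.0.attn2.to_k.lora_B.weight": torch.randn(16, rank),
}

# Infer rank and dtype from any existing lora_A weight, as the patched code does.
ref = next(v for k, v in state_dict.items() if k.endswith(".lora_A.weight"))
inferred_rank, lora_dtype = ref.shape[0], ref.dtype

# Pad zero-initialized lora_A/lora_B for every vace_blocks.{i}.proj_out; since
# B @ A == 0, the expanded LoRA does not change what the VACE layers compute.
for i, block in enumerate(transformer.vace_blocks):
    proj_out = block.proj_out
    state_dict.setdefault(
        f"vace_blocks.{i}.proj_out.lora_A.weight",
        torch.zeros((inferred_rank, proj_out.in_features), device=target_device, dtype=lora_dtype),
    )
    state_dict.setdefault(
        f"vace_blocks.{i}.proj_out.lora_B.weight",
        torch.zeros((proj_out.out_features, inferred_rank), device=target_device, dtype=lora_dtype),
    )

print(sorted(state_dict))  # now also lists vace_blocks.{0,1}.proj_out.lora_{A,B}.weight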
