
Commit b8a371e

vace
1 parent f81d4c6 commit b8a371e

File tree: 1 file changed (+53, -3 lines)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 53 additions & 3 deletions
@@ -4859,11 +4859,57 @@ def _maybe_expand_t2v_lora_for_vace(
         transformer: torch.nn.Module,
         state_dict,
     ):
+        target_device = transformer.device
+        if hasattr(transformer, 'vace_blocks'):
+            inferred_rank_for_vace = None
+            lora_weights_dtype_for_vace = next(iter(transformer.parameters())).dtype  # Fallback dtype
+
+            for k_lora_any, v_lora_tensor_any in state_dict.items():
+                if k_lora_any.endswith(".lora_A.weight"):
+                    inferred_rank_for_vace = v_lora_tensor_any.shape[0]
+                    lora_weights_dtype_for_vace = v_lora_tensor_any.dtype
+                    break  # Found one, good enough for rank and dtype
+
+            if inferred_rank_for_vace is not None:
+                # Determine if the LoRA format (as potentially modified by I2V expansion) includes bias.
+                # This re-checks 'has_bias' based on the *current* state_dict.
+                current_lora_has_bias = any(".lora_B.bias" in k for k in state_dict.keys())
+
+                for i, vace_block_module_in_model in enumerate(transformer.vace_blocks):
+                    # Specifically target proj_out as per the error message
+                    if hasattr(vace_block_module_in_model, 'proj_out') and \
+                            isinstance(vace_block_module_in_model.proj_out, nn.Linear):
+
+                        proj_out_linear_layer_in_model = vace_block_module_in_model.proj_out
+
+                        vace_lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
+                        vace_lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight"
+
+                        if vace_lora_A_key not in state_dict:
+                            state_dict[vace_lora_A_key] = torch.zeros(
+                                (inferred_rank_for_vace, proj_out_linear_layer_in_model.in_features),
+                                device=target_device, dtype=lora_weights_dtype_for_vace
+                            )
+
+                        if vace_lora_B_key not in state_dict:
+                            state_dict[vace_lora_B_key] = torch.zeros(
+                                (proj_out_linear_layer_in_model.out_features, inferred_rank_for_vace),
+                                device=target_device, dtype=lora_weights_dtype_for_vace
+                            )
+
+                        # Use 'current_lora_has_bias' to decide on padding bias for VACE blocks
+                        if current_lora_has_bias and proj_out_linear_layer_in_model.bias is not None:
+                            vace_lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias"
+                            if vace_lora_B_bias_key not in state_dict:
+                                state_dict[vace_lora_B_bias_key] = torch.zeros_like(
+                                    proj_out_linear_layer_in_model.bias,  # Shape from model's bias
+                                    device=target_device  # Dtype from model's bias implicitly by zeros_like
+                                )
+
+            print("AFTER 2:", list(state_dict.keys()))
+            return state_dict
 
-        if not hasattr(transformer, 'vace_blocks'):
-            return state_dict
 
-        target_device = transformer.device
 
         return state_dict

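Why the zero-filled padding in the hunk above is safe (a minimal standalone sketch, not part of the commit; tensor sizes are illustrative): a LoRA update adds the product of the B and A factors to the base layer's output, so all-zero lora_A/lora_B tensors for the VACE blocks' proj_out layers satisfy the expected keys and shapes without changing what proj_out computes.

# Standalone illustration; shapes follow the same (rank, in_features) / (out_features, rank) rule used above.
import torch
import torch.nn as nn

rank, in_features, out_features = 16, 128, 64  # illustrative sizes
proj_out = nn.Linear(in_features, out_features)

lora_A = torch.zeros(rank, in_features)   # zero-padded down-projection
lora_B = torch.zeros(out_features, rank)  # zero-padded up-projection

x = torch.randn(2, in_features)
delta = x @ lora_A.T @ lora_B.T           # LoRA contribution for this input: all zeros
assert torch.allclose(proj_out(x) + delta, proj_out(x))  # zero pads are a no-op
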
@@ -4915,6 +4961,10 @@ def load_lora_weights(
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             state_dict=state_dict,
         )
+        state_dict = self._maybe_expand_t2v_lora_for_vace(
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            state_dict=state_dict,
+        )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
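
Usage sketch for the second hunk (assumptions: WanVACEPipeline as the host pipeline, and placeholder checkpoint/LoRA paths; neither is taken from this commit): load_lora_weights now routes the loaded state dict through _maybe_expand_t2v_lora_for_vace, so a text-to-video LoRA that has no vace_blocks.*.proj_out entries can still be applied to a VACE transformer.

# Hedged usage sketch; model IDs and file paths are placeholders, not verified assets.
import torch
from diffusers import WanVACEPipeline  # assumption: the VACE pipeline that uses this LoRA loader

pipe = WanVACEPipeline.from_pretrained(
    "path/or/hub-id/of-a-wan-vace-checkpoint",  # placeholder
    torch_dtype=torch.bfloat16,
)

# A T2V-only LoRA lacks vace_blocks.{i}.proj_out keys; during load_lora_weights the new
# _maybe_expand_t2v_lora_for_vace step zero-pads them so the adapters attach cleanly.
pipe.load_lora_weights("path/to/t2v_lora.safetensors")  # placeholder path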
