@@ -4813,89 +4813,43 @@ def _maybe_expand_t2v_lora_for_i2v(
         if transformer.config.image_dim is None:
             return state_dict
 
-        # Determine the target device from the transformer
-        try:
-            target_device = next(transformer.parameters()).device
-        except StopIteration:
-            # Fallback if transformer has no parameters (should be rare for a full model)
-            # Try to infer from existing tensors in state_dict, else default to CPU
-            target_device = torch.device("cpu")
-            for v in state_dict.values():
-                if torch.is_tensor(v):
-                    target_device = v.device
-                    break
-
-        # Debug prints, can be removed after verification
-        logger.debug(f"Target device for new LoRA tensors: {target_device}")
-        logger.debug(f"Keys BEFORE _maybe_expand_t2v_lora_for_i2v: {list(state_dict.keys())[:5]} ...")
+        target_device = transformer.device
 
         if any(k.startswith("transformer.blocks.") for k in state_dict):
-            block_indices = set()
-            # Iterate over a copy of keys if state_dict might be modified during iteration elsewhere (not here though)
-            for k in list(state_dict.keys()):
-                if k.startswith("transformer.blocks.") and ".attn2.to_k.lora_A.weight" in k:
-                    try:
-                        block_idx_str = k.split("blocks.")[1].split(".")[0]
-                        block_indices.add(int(block_idx_str))
-                    except (IndexError, ValueError):
-                        logger.warning(f"Could not parse block index from key: {k}")
-                        continue
-
-            if not block_indices:
-                logger.debug(
-                    "No valid blocks found for T2V to I2V expansion referencing based on '.attn2.to_k.lora_A.weight'.")
-                return state_dict
-
-            num_total_blocks = max(block_indices) + 1
-
+            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
             is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+            has_bias = any(".lora_B.bias" in k for k in state_dict)
 
-            # Check for bias keys that would have been converted by _convert_non_diffusers_wan_lora_to_diffusers
-            # e.g., 'transformer.blocks.0.attn2.to_k.lora_B.bias'
-            # This helps decide if zero biases should be added for the new projections.
-            has_bias_key_pattern_in_sd = any(".lora_B.bias" in k for k in state_dict)
-
-            if is_i2v_lora:  # If it's already an I2V LoRA, no expansion needed
+            if is_i2v_lora:
                 return state_dict
 
-            logger.info(
-                "Adapting a T2V LoRA for I2V model by adding zero-initialized weights for image-specific cross-attention layers."
-            )
-            for i in range(num_total_blocks):
-                # Define reference key patterns carefully. These keys should exist if the block `i` was part of the T2V LoRA.
-                ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
-                ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
-
-                # Only proceed if the reference LoRA weights for this block exist in the T2V LoRA
-                if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
-                    continue
+            for i in range(num_blocks):
+                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    # These keys should exist if the block `i` was part of the T2V LoRA.
+                    ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+                    ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
 
-                ref_lora_A_weight = state_dict[ref_key_lora_A]
-                ref_lora_B_weight = state_dict[ref_key_lora_B]
-
-                # Use dtype from reference LoRA tensors, device from transformer
-                lora_dtype = ref_lora_A_weight.dtype
+                    if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+                        continue
+                    ref_lora_A_weight = state_dict[ref_key_lora_A]
+                    lora_dtype = ref_lora_A_weight.dtype
 
-                for _unused_orig_name, diffusers_name_part in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                    # Create new tensors on the transformer's device and with the LoRA's dtype
-                    state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_A.weight"] = torch.zeros(
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros(
                         ref_lora_A_weight.shape, device=target_device, dtype=lora_dtype
                     )
-                    state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_B.weight"] = torch.zeros(
-                        ref_lora_B_weight.shape, device=target_device, dtype=lora_dtype
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
                     )
 
-                    # If the original LoRA had biases (indicated by has_bias_key_pattern_in_sd)
+                    # If the original LoRA had biases (indicated by has_bias)
                     # AND the specific reference bias key exists for this block.
+
                     ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
-                    if has_bias_key_pattern_in_sd and ref_key_lora_B_bias in state_dict:
+                    if has_bias and ref_key_lora_B_bias in state_dict:
                         ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
-                        state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_B.bias"] = torch.zeros(
+                        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros(
                             ref_lora_B_bias_tensor.shape, device=target_device, dtype=lora_dtype
                         )
 
-            # new_keys = set(state_dict.keys()) - orig_keys
-            # logger.debug(f"Keys ADDED by _maybe_expand_t2v_lora_for_i2v: {new_keys}")
         return state_dict
 
     def load_lora_weights(
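For reference, a minimal standalone sketch of the expansion the new code performs, applied to a toy single-block T2V LoRA state dict. The rank/dim shapes are illustrative assumptions, not values taken from the model; only the key layout mirrors the diff above.

import torch

# Toy T2V LoRA entries for one block; shapes are placeholder assumptions.
rank, dim = 4, 16
state_dict = {
    "transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(rank, dim),
    "transformer.blocks.0.attn2.to_k.lora_B.weight": torch.randn(dim, rank),
}

# Same idea as the helper above: add zero-initialized LoRA weights for the
# image cross-attention projections, so their LoRA delta is a no-op when a
# T2V LoRA is loaded into an I2V transformer.
for i in range(1):  # one block in this toy dict
    for c in ["add_k_proj", "add_v_proj"]:
        ref_A = state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
        ref_B = state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(ref_A)
        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(ref_B)

print(sorted(state_dict))  # now also lists add_k_proj / add_v_proj entries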