@@ -4813,30 +4813,89 @@ def _maybe_expand_t2v_lora_for_i2v(
         if transformer.config.image_dim is None:
             return state_dict
 
-        target_device = transformer.device
+        # Determine the target device from the transformer
+        try:
+            target_device = next(transformer.parameters()).device
+        except StopIteration:
+            # Fallback if transformer has no parameters (should be rare for a full model)
+            # Try to infer from existing tensors in state_dict, else default to CPU
+            target_device = torch.device("cpu")
+            for v in state_dict.values():
+                if torch.is_tensor(v):
+                    target_device = v.device
+                    break
+
+        # Debug prints, can be removed after verification
+        logger.debug(f"Target device for new LoRA tensors: {target_device}")
+        logger.debug(f"Keys BEFORE _maybe_expand_t2v_lora_for_i2v: {list(state_dict.keys())[:5]} ...")
 
         if any(k.startswith("transformer.blocks.") for k in state_dict):
-            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
+            block_indices = set()
+            # Iterate over a copy of keys if state_dict might be modified during iteration elsewhere (not here though)
+            for k in list(state_dict.keys()):
+                if k.startswith("transformer.blocks.") and ".attn2.to_k.lora_A.weight" in k:
+                    try:
+                        block_idx_str = k.split("blocks.")[1].split(".")[0]
+                        block_indices.add(int(block_idx_str))
+                    except (IndexError, ValueError):
+                        logger.warning(f"Could not parse block index from key: {k}")
+                        continue
+
+            if not block_indices:
+                logger.debug(
+                    "No valid blocks found for T2V to I2V expansion based on '.attn2.to_k.lora_A.weight' keys."
+                )
+                return state_dict
+
+            num_total_blocks = max(block_indices) + 1
+
             is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
-            is_bias = any("bias" in k for k in state_dict)
 
-            if is_i2v_lora:
+            # Check for bias keys that would have been converted by _convert_non_diffusers_wan_lora_to_diffusers
+            # e.g., 'transformer.blocks.0.attn2.to_k.lora_B.bias'
+            # This helps decide if zero biases should be added for the new projections.
+            has_bias_key_pattern_in_sd = any(".lora_B.bias" in k for k in state_dict)
+
+            if is_i2v_lora:  # If it's already an I2V LoRA, no expansion needed
                 return state_dict
 
-            for i in range(num_blocks):
-                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
+            logger.info(
+                "Adapting a T2V LoRA for an I2V model by adding zero-initialized weights for the image-specific cross-attention layers."
+            )
+            for i in range(num_total_blocks):
+                # Define reference key patterns carefully. These keys should exist if the block `i` was part of the T2V LoRA.
+                ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+                ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
+
+                # Only proceed if the reference LoRA weights for this block exist in the T2V LoRA
+                if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+                    continue
+
+                ref_lora_A_weight = state_dict[ref_key_lora_A]
+                ref_lora_B_weight = state_dict[ref_key_lora_B]
+
+                # Use dtype from reference LoRA tensors, device from transformer
+                lora_dtype = ref_lora_A_weight.dtype
+
+                for _unused_orig_name, diffusers_name_part in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    # Create new tensors on the transformer's device and with the LoRA's dtype
+                    state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_A.weight"] = torch.zeros(
+                        ref_lora_A_weight.shape, device=target_device, dtype=lora_dtype
                     )
-                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
+                    state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_B.weight"] = torch.zeros(
+                        ref_lora_B_weight.shape, device=target_device, dtype=lora_dtype
                     )
-                    if is_bias:
-                        state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
-                            state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"], device=target_device
-                        )
 
+                    # If the original LoRA had biases (indicated by has_bias_key_pattern_in_sd)
+                    # AND the specific reference bias key exists for this block.
+                    ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
+                    if has_bias_key_pattern_in_sd and ref_key_lora_B_bias in state_dict:
+                        ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
+                        state_dict[f"transformer.blocks.{i}.attn2.{diffusers_name_part}.lora_B.bias"] = torch.zeros(
+                            ref_lora_B_bias_tensor.shape, device=target_device, dtype=lora_dtype
+                        )
 
+        # new_keys = set(state_dict.keys()) - orig_keys
+        # logger.debug(f"Keys ADDED by _maybe_expand_t2v_lora_for_i2v: {new_keys}")
         return state_dict
 
     def load_lora_weights(
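
For reference, a minimal usage sketch of the path this change affects: loading a T2V-trained Wan LoRA into an I2V pipeline, where load_lora_weights routes the state dict through _maybe_expand_t2v_lora_for_i2v. This is not part of the diff; the checkpoint id and LoRA path below are placeholders, not taken from this PR.

import torch
from diffusers import WanImageToVideoPipeline

# Assumed I2V checkpoint id and a hypothetical T2V LoRA file; substitute your own.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# A T2V LoRA loaded here is expanded with zero-initialized add_k_proj/add_v_proj
# entries so the adapter also covers the I2V image cross-attention projections
# (as a no-op) instead of leaving them without LoRA weights.
pipe.load_lora_weights("path/to/t2v_lora.safetensors", adapter_name="t2v_style")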