 from typing import Callable, Dict, List, Optional, Union

 import torch
-import re
 from huggingface_hub.utils import validate_hf_hub_args

 from ..utils import (
@@ -4806,152 +4805,50 @@ def lora_state_dict(
         return state_dict

     @classmethod
-    def _modified_maybe_expand_t2v_lora(  # Renamed for clarity
-        # cls,  # if it were a classmethod
-        transformer: torch.nn.Module,
-        state_dict: Dict[str, torch.Tensor],
-        lora_filename_for_rank_inference: Optional[str] = None  # Optional: for rank hint
-    ) -> Dict[str, torch.Tensor]:
+    def _maybe_expand_t2v_lora_for_i2v(
+        cls,
+        transformer: torch.nn.Module,
+        state_dict,
+    ):
+        if transformer.config.image_dim is None:
+            return state_dict

         target_device = transformer.device
-        # Default dtype from transformer, can be refined if LoRA weights have a different one
-        lora_weights_dtype = next(iter(transformer.parameters())).dtype
-
-        # --- Infer LoRA rank and potentially refine dtype from existing LoRA weights ---
-        inferred_rank = None
-        if state_dict:  # If LoRA state_dict already has entries from the T2V LoRA
-            for k, v_tensor in state_dict.items():
-                if k.endswith(".lora_A.weight"):  # Standard LoRA weight key part
-                    inferred_rank = v_tensor.shape[0]  # rank is the output dim of lora_A
-                    lora_weights_dtype = v_tensor.dtype  # Use dtype of existing LoRA weights
-                    break  # Found rank and dtype
-
-        if inferred_rank is None and lora_filename_for_rank_inference:
-            match = re.search(r"rank(\d+)", lora_filename_for_rank_inference, re.IGNORECASE)
-            if match:
-                inferred_rank = int(match.group(1))
-                print(f"INFO: Inferred LoRA rank {inferred_rank} from filename for padding.")
-
-        # Determine if the original LoRA format (the T2V part) uses biases for lora_B
-        lora_format_has_bias = any(".lora_B.bias" in k for k in state_dict.keys())
-
-        # --- Part 1: Original I2V expansion for standard transformer.blocks ---
-        # (Assuming transformer.config and transformer.blocks structure for this part)
-        if hasattr(transformer, 'config') and hasattr(transformer.config, 'image_dim') and \
-           transformer.config.image_dim is not None and hasattr(transformer, 'blocks'):
-
-            standard_block_keys_present = any(k.startswith("transformer.blocks.") for k in state_dict)
-
-            if standard_block_keys_present and inferred_rank is not None:
-                num_blocks_in_lora = 0
-                block_indices = set()
-                for k_lora in state_dict:
-                    if "transformer.blocks." in k_lora:
-                        try:
-                            block_idx_str = k_lora.split("transformer.blocks.")[1].split(".")[0]
-                            if block_idx_str.isdigit():
-                                block_indices.add(int(block_idx_str))
-                        except IndexError:
-                            pass
-                if block_indices:
-                    num_blocks_in_lora = max(block_indices) + 1
-
-                is_i2v_lora_standard_blocks = any(
-                    k.startswith("transformer.blocks.") and "add_k_proj" in k for k in state_dict
-                ) and any(
-                    k.startswith("transformer.blocks.") and "add_v_proj" in k for k in state_dict
-                )

-                if not is_i2v_lora_standard_blocks and num_blocks_in_lora > 0:
-                    print(f"INFO: Expanding T2V LoRA for I2V compatibility (standard blocks). Rank: {inferred_rank}")
-                    for i in range(num_blocks_in_lora):
-                        # Check if block 'i' relevant parts are in the T2V LoRA
-                        ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
-                        if ref_key_lora_A not in state_dict:
-                            continue  # This block's specific part wasn't in the LoRA.
-
-                        try:
-                            model_block = transformer.blocks[i]
-                            # Ensure these target layers exist in the model's standard block
-                            if not (hasattr(model_block, 'attn2') and \
-                                    hasattr(model_block.attn2, 'add_k_proj') and \
-                                    hasattr(model_block.attn2, 'add_v_proj')):
-                                continue
-                            add_k_proj_layer = model_block.attn2.add_k_proj
-                            add_v_proj_layer = model_block.attn2.add_v_proj
-                        except (AttributeError, IndexError):
-                            print(f"WARN: Cannot access standard block {i} or its I2V layers for expansion.")
-                            continue
-
-                        for proj_name_suffix, model_linear_layer in [("add_k_proj", add_k_proj_layer),
-                                                                     ("add_v_proj", add_v_proj_layer)]:
-                            if not isinstance(model_linear_layer, nn.Linear): continue
-
-                            lora_A_key = f"transformer.blocks.{i}.attn2.{proj_name_suffix}.lora_A.weight"
-                            lora_B_key = f"transformer.blocks.{i}.attn2.{proj_name_suffix}.lora_B.weight"
-
-                            if lora_A_key not in state_dict:
-                                state_dict[lora_A_key] = torch.zeros(
-                                    (inferred_rank, model_linear_layer.in_features),
-                                    device=target_device, dtype=lora_weights_dtype
-                                )
-                            if lora_B_key not in state_dict:
-                                state_dict[lora_B_key] = torch.zeros(
-                                    (model_linear_layer.out_features, inferred_rank),
-                                    device=target_device, dtype=lora_weights_dtype
-                                )
-
-                            if lora_format_has_bias and model_linear_layer.bias is not None:
-                                lora_B_bias_key = f"transformer.blocks.{i}.attn2.{proj_name_suffix}.lora_B.bias"
-                                if lora_B_bias_key not in state_dict:
-                                    state_dict[lora_B_bias_key] = torch.zeros_like(
-                                        model_linear_layer.bias, device=target_device,
-                                        dtype=model_linear_layer.bias.dtype
-                                    )
-            elif inferred_rank is None:
-                print("INFO: LoRA rank not inferred. Skipping I2V expansion for standard blocks.")
-            # else: not standard_block_keys_present or no I2V capability.
-
-        # --- Part 2: Pad LoRA for WanVACETransformer3DModel vace_blocks.X.proj_out ---
-        # Dynamically check for WanVACETransformer3DModel availability for isinstance
-        VACEModelClass = globals().get("WanVACETransformer3DModel")
-
-        if VACEModelClass and isinstance(transformer, VACEModelClass) and hasattr(transformer, 'vace_blocks'):
-            if inferred_rank is None:
-                print("WARNING: LoRA rank not determined. Skipping VACE block padding for proj_out.")
-            else:
-                print(f"INFO: Transformer is WanVACE. Padding LoRA for vace_blocks.X.proj_out. Rank: {inferred_rank}")
-                for i, vace_block_module in enumerate(transformer.vace_blocks):
-                    if hasattr(vace_block_module, 'proj_out') and isinstance(vace_block_module.proj_out, nn.Linear):
-                        proj_out_layer = vace_block_module.proj_out
-
-                        # Keys for the vace_block's proj_out LoRA layers
-                        # These are the keys PEFT expects in the state_dict *before* adding adapter name context
-                        lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
-                        lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight"
-
-                        if lora_A_key not in state_dict:
-                            state_dict[lora_A_key] = torch.zeros(
-                                (inferred_rank, proj_out_layer.in_features),
-                                device=target_device, dtype=lora_weights_dtype
-                            )
-                            # print(f"Padded: {lora_A_key}")
-
-                        if lora_B_key not in state_dict:
-                            state_dict[lora_B_key] = torch.zeros(
-                                (proj_out_layer.out_features, inferred_rank),
-                                device=target_device, dtype=lora_weights_dtype
-                            )
-                            # print(f"Padded: {lora_B_key}")
-
-                        if lora_format_has_bias and proj_out_layer.bias is not None:
-                            lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias"
-                            if lora_B_bias_key not in state_dict:
-                                state_dict[lora_B_bias_key] = torch.zeros_like(
-                                    proj_out_layer.bias, device=target_device, dtype=proj_out_layer.bias.dtype
-                                )
-                                # print(f"Padded: {lora_B_bias_key}")
-                    # else: VACE block 'i' might not have proj_out or it's not Linear.
+        if any(k.startswith("transformer.blocks.") for k in state_dict):
+            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
+            is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+            has_bias = any(".lora_B.bias" in k for k in state_dict)
+
+            if is_i2v_lora:
+                return state_dict
+
+            for i in range(num_blocks):
+                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    # These keys should exist if block `i` was part of the T2V LoRA.
+                    ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+                    ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
+
+                    if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+                        continue
+
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
+                    )
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
+                    )
+
+                    # Pad a zero bias only if the original LoRA carries biases (has_bias)
+                    # and the reference bias key actually exists for this block.
+                    ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
+                    if has_bias and ref_key_lora_B_bias in state_dict:
+                        ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
+                        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
+                            ref_lora_B_bias_tensor,
+                            device=target_device,
+                        )

         return state_dict
49574854
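For reference, a minimal standalone sketch of what the new `_maybe_expand_t2v_lora_for_i2v` padding amounts to: zero-filled `lora_A`/`lora_B` factors are registered for the I2V-only `add_k_proj`/`add_v_proj` layers, which leaves the adapter a no-op there, since `lora_B @ lora_A == 0`. The rank and feature sizes below are toy values, and the snippet uses plain torch rather than the loader internals:

import torch

rank, dim = 4, 16  # toy values, not taken from any real checkpoint
state_dict = {
    # A T2V-style LoRA entry for one cross-attention projection.
    "transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(rank, dim),
    "transformer.blocks.0.attn2.to_k.lora_B.weight": torch.randn(dim, rank),
}

# Pad zero-filled factors for the image-conditioning projections, mirroring
# the shapes of the existing to_k factors (as the diff above does).
for c in ["add_k_proj", "add_v_proj"]:
    state_dict[f"transformer.blocks.0.attn2.{c}.lora_A.weight"] = torch.zeros_like(
        state_dict["transformer.blocks.0.attn2.to_k.lora_A.weight"]
    )
    state_dict[f"transformer.blocks.0.attn2.{c}.lora_B.weight"] = torch.zeros_like(
        state_dict["transformer.blocks.0.attn2.to_k.lora_B.weight"]
    )

# Both factors are zero, so the padded layers contribute delta_W = B @ A = 0:
# loading the expanded LoRA into an I2V model changes nothing for those layers.
delta_w = (
    state_dict["transformer.blocks.0.attn2.add_k_proj.lora_B.weight"]
    @ state_dict["transformer.blocks.0.attn2.add_k_proj.lora_A.weight"]
)
assert torch.all(delta_w == 0)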
@@ -5816,4 +5713,4 @@ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
         deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
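The shim above follows the usual deprecation pattern: warn once at construction time, then defer to the replacement class. A generic sketch of that pattern with plain `warnings` (diffusers uses its own `deprecate` helper, as shown in the diff; the stand-in base class here is illustrative only):

import warnings

class StableDiffusionLoraLoaderMixin:  # stand-in for the real mixin
    def __init__(self, *args, **kwargs):
        pass

class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    def __init__(self, *args, **kwargs):
        # Emit a FutureWarning so callers see the rename before removal.
        warnings.warn(
            "LoraLoaderMixin is deprecated; use StableDiffusionLoraLoaderMixin instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)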