@@ -4249,6 +4249,31 @@ def lora_state_dict(
42494249
42504250 return state_dict
42514251
@classmethod
def maybe_expand_t2v_lora_for_i2v(
    cls,
    transformer: torch.nn.Module,
    state_dict,
):
    """Pad a text-to-video LoRA with zero image-projection entries.

    For every block index found in ``state_dict`` (keys of the form
    ``blocks.<idx>...``), a zero tensor is stored under
    ``blocks.<idx>.attn2.add_k_proj.lora_{A,B}.weight`` and
    ``blocks.<idx>.attn2.add_v_proj.lora_{A,B}.weight``, shaped like the
    matching ``blocks.<idx>.attn2.k`` / ``blocks.<idx>.attn2.v`` tensor.

    The mapping is returned untouched when:

    * ``transformer.config.image_dim`` is ``None``, or
    * the checkpoint already contains both ``k_img`` and ``v_img`` keys
      (i.e. it is already an image-to-video LoRA).

    Args:
        transformer: Module exposing ``config.image_dim``.
        state_dict: Mapping from LoRA parameter names to tensors. It is
            modified in place.

    Returns:
        The same ``state_dict`` object, possibly with extra zero entries.
    """
    # Cheap guard first so that the key scan below never runs needlessly.
    # NOTE(review): image_dim is None presumably marks a transformer
    # without image conditioning - confirm against the model config.
    if transformer.config.image_dim is None:
        return state_dict

    has_k_img = any("k_img" in key for key in state_dict)
    has_v_img = any("v_img" in key for key in state_dict)
    if has_k_img and has_v_img:
        # Already an I2V LoRA: there is nothing to expand.
        return state_dict

    # Use the block indices that are really present instead of assuming
    # they are contiguous from zero, and ignore keys without "blocks.".
    # Sorting by (length, text) orders digit strings numerically.
    block_ids = sorted(
        {
            key.split("blocks.")[1].split(".")[0]
            for key in state_dict
            if "blocks." in key
        },
        key=lambda block_id: (len(block_id), block_id),
    )

    projections = (("k", "add_k_proj"), ("v", "add_v_proj"))
    for block_id in block_ids:
        for source_name, target_name in projections:
            for lora_part in ("lora_A", "lora_B"):
                prefix = f"blocks.{block_id}.attn2."
                source_key = f"{prefix}{source_name}.{lora_part}.weight"
                target_key = f"{prefix}{target_name}.{lora_part}.weight"
                # Skip blocks the LoRA does not cover, and never replace
                # weights that are already present with zeros.
                if source_key not in state_dict or target_key in state_dict:
                    continue
                state_dict[target_key] = torch.zeros_like(state_dict[source_key])

    return state_dict


# The loader calls this helper through its underscore-prefixed name; keep
# both spellings bound to the same classmethod.
_maybe_expand_t2v_lora_for_i2v = maybe_expand_t2v_lora_for_i2v
4276+
42524277 # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
42534278 def load_lora_weights (
42544279 self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], adapter_name = None , ** kwargs
@@ -4287,7 +4312,10 @@ def load_lora_weights(
42874312
42884313 # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
42894314 state_dict = self .lora_state_dict (pretrained_model_name_or_path_or_dict , ** kwargs )
4290-
4315+ state_dict = self ._maybe_expand_t2v_lora_for_i2v (
4316+ transformer = getattr (self , self .transformer_name ) if not hasattr (self ,
4317+ "transformer" ) else self .transformer ,
4318+ state_dict = state_dict )
42914319 is_correct_format = all ("lora" in key for key in state_dict .keys ())
42924320 if not is_correct_format :
42934321 raise ValueError ("Invalid LoRA checkpoint." )
0 commit comments