
Commit f351017

@hlky committed: t2v->i2v

1 parent 3be6706

File tree

1 file changed (+29 −1)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 29 additions & 1 deletion
@@ -4249,6 +4249,31 @@ def lora_state_dict(
 
         return state_dict
 
+    @classmethod
+    def _maybe_expand_t2v_lora_for_i2v(
+        cls,
+        transformer: torch.nn.Module,
+        state_dict,
+    ):
+        num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
+        is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)
+        if not is_i2v_lora:
+            return state_dict
+
+        if transformer.config.image_dim is None:
+            return state_dict
+
+        for i in range(num_blocks):
+            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+                    state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"]
+                )
+                state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                    state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"]
+                )
+
+        return state_dict
+
     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
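
In plain terms, the new helper's job: Wan's I2V transformer carries two extra cross-attention projections per block (add_k_proj and add_v_proj, attending to image embeddings) that a LoRA trained against the T2V transformer has no weights for. Padding the checkpoint with zero-filled lora_A/lora_B matrices makes it structurally loadable into the I2V model while changing nothing at inference, since the LoRA delta B @ A is then zero. Below is a minimal standalone sketch of that padding; the shapes (2 blocks, rank 4, width 16) are made up and the to_q/to_k/to_v key names are assumed for the converted T2V checkpoint, so treat it as an illustration rather than the loader's exact key layout.

import torch

rank, dim, num_blocks = 4, 16, 2

# A hypothetical T2V LoRA in diffusers-style layout: cross-attention (attn2)
# q/k/v projections only, no image projections. Shapes are illustrative.
state_dict = {}
for i in range(num_blocks):
    for proj in ["to_q", "to_k", "to_v"]:
        state_dict[f"blocks.{i}.attn2.{proj}.lora_A.weight"] = torch.randn(rank, dim)
        state_dict[f"blocks.{i}.attn2.{proj}.lora_B.weight"] = torch.randn(dim, rank)

# Pad the missing image projections with zeros, mirroring the helper above.
for i in range(num_blocks):
    for c in ["add_k_proj", "add_v_proj"]:
        for mat in ["lora_A", "lora_B"]:
            state_dict[f"blocks.{i}.attn2.{c}.{mat}.weight"] = torch.zeros_like(
                state_dict[f"blocks.{i}.attn2.to_k.{mat}.weight"]
            )

# The padded entries are inert: their LoRA delta (B @ A) is exactly zero.
delta = (
    state_dict["blocks.0.attn2.add_k_proj.lora_B.weight"]
    @ state_dict["blocks.0.attn2.add_k_proj.lora_A.weight"]
)
assert torch.count_nonzero(delta) == 0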
@@ -4287,7 +4312,10 @@ def load_lora_weights(
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
+        state_dict = self._maybe_expand_t2v_lora_for_i2v(
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            state_dict=state_dict,
+        )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
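
At the call site, the freshly fetched state dict is routed through the expansion before the "lora"-key validation; the transformer is resolved from self.transformer when the pipeline exposes that attribute, otherwise via getattr(self, self.transformer_name). For callers nothing changes. A sketch of what this enables, loading a T2V-trained LoRA into an I2V pipeline (the hub ids below are placeholders, not verified checkpoints):

import torch
from diffusers import WanImageToVideoPipeline

# Placeholder model id, for illustration only.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)

# A LoRA trained on the T2V transformer: load_lora_weights() now pads it with
# zero image-projection weights before handing it to the PEFT-backed loader.
pipe.load_lora_weights("some-user/wan-t2v-style-lora", adapter_name="style")
pipe.set_adapters(["style"], adapter_weights=[0.8])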

0 commit comments