-
Couldn't load subscription status.
- Fork 6.4k
[Wan LoRAs] make T2V LoRAs compatible with Wan I2V #11107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
f351017
5e6a15b
ccdc4fd
b5689c4
fe2d3b4
6637a12
63e581c
9fa3d93
051f534
3834c16
292d618
440001c
be5a01a
92aabcb
51c570d
f5b5986
0f3a48f
d2dd6ae
c464455
86cbc0f
6c39465
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4249,7 +4249,32 @@ def lora_state_dict( | |
|
|
||
| return state_dict | ||
|
|
||
| # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights | ||
| @classmethod | ||
| def _maybe_expand_t2v_lora_for_i2v( | ||
| cls, | ||
| transformer: torch.nn.Module, | ||
| state_dict, | ||
| ): | ||
| if any(k.startswith("blocks.") for k in state_dict): | ||
| num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict}) | ||
| is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict) | ||
| if not is_i2v_lora: | ||
linoytsaban marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return state_dict | ||
|
||
|
|
||
| if transformer.config.image_dim is None: | ||
| return state_dict | ||
|
||
|
|
||
| for i in range(num_blocks): | ||
| for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): | ||
| state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like( | ||
| state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"] | ||
| ) | ||
| state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like( | ||
| state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"] | ||
| ) | ||
linoytsaban marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| return state_dict | ||
|
|
||
| def load_lora_weights( | ||
| self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs | ||
| ): | ||
|
|
@@ -4287,7 +4312,11 @@ def load_lora_weights( | |
|
|
||
| # First, ensure that the checkpoint is a compatible one and can be successfully loaded. | ||
| state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) | ||
|
|
||
| # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers | ||
| state_dict = self._maybe_expand_t2v_lora_for_i2v( | ||
| transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, | ||
| state_dict=state_dict, | ||
| ) | ||
| is_correct_format = all("lora" in key for key in state_dict.keys()) | ||
| if not is_correct_format: | ||
| raise ValueError("Invalid LoRA checkpoint.") | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.