
Commit 8d1de40

Authored by linoytsaban, with co-authors sayakpaul, github-actions[bot], and linoy
[Wan 2.2 LoRA] add support for 2nd transformer lora loading + wan 2.2 lightx2v lora (#12074)
* add alpha
* load into 2nd transformer
* Update src/diffusers/loaders/lora_conversion_utils.py (Co-authored-by: Sayak Paul <[email protected]>)
* Update src/diffusers/loaders/lora_conversion_utils.py (Co-authored-by: Sayak Paul <[email protected]>)
* pr comments
* pr comments
* pr comments
* fix
* fix
* Apply style fixes
* fix copies
* fix
* fix copies
* Update src/diffusers/loaders/lora_pipeline.py (Co-authored-by: Sayak Paul <[email protected]>)
* revert change
* revert change
* fix copies
* up
* fix

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: linoy <[email protected]>
1 parent 8cc528c commit 8d1de40

4 files changed: +150 −55 lines

docs/source/en/api/pipelines/wan.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -333,6 +333,8 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
 
 - Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them on Wan 2.2 is slightly more involved. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.
 
+- Wan 2.2 has two denoisers. By default, LoRAs are only loaded into the first denoiser. One can set `load_into_transformer_2=True` to load LoRAs into the second denoiser. Refer to [this example](https://github.com/huggingface/diffusers/pull/12074#issue-3292620048) and [this one](https://github.com/huggingface/diffusers/pull/12074#issuecomment-3155896144) to learn more.
+
 ## WanPipeline
 
 [[autodoc]] WanPipeline
```

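For context, here is a minimal usage sketch of what the new documentation line describes. The checkpoint id and LoRA filename are illustrative placeholders, not taken from this commit; only the `load_into_transformer_2=True` kwarg is what the PR adds.

```python
import torch
from diffusers import WanPipeline

# Illustrative model and LoRA identifiers; substitute the ones you actually use.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16)

# Default behavior: the LoRA is loaded into the first denoiser (`transformer`).
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_lora_rank64.safetensors",  # illustrative filename
    adapter_name="lightx2v",
)

# New in this PR: also load it into the second denoiser (`transformer_2`).
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_lora_rank64.safetensors",  # illustrative filename
    adapter_name="lightx2v_2",
    load_into_transformer_2=True,
)
```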
src/diffusers/loaders/lora_base.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -754,7 +754,11 @@ def set_adapters(
         # Decompose weights into weights for denoiser and text encoders.
         _component_adapter_weights = {}
         for component in self._lora_loadable_modules:
-            model = getattr(self, component)
+            model = getattr(self, component, None)
+            # To guard for cases like Wan. In Wan2.1 and WanVace, we have a single denoiser.
+            # Whereas in Wan 2.2, we have two denoisers.
+            if model is None:
+                continue
 
             for adapter_name, weights in zip(adapter_names, adapter_weights):
                 if isinstance(weights, dict):
```

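A standalone toy sketch (not diffusers code; all names below are hypothetical) of why this default-and-skip guard matters: `_lora_loadable_modules` now lists `transformer_2` for every Wan pipeline, but Wan 2.1 and WanVACE pipelines have only one denoiser, so `set_adapters` must skip missing components instead of raising.

```python
class ToyWanPipeline:
    """Toy stand-in for a Wan 2.1 pipeline: only one denoiser exists."""

    _lora_loadable_modules = ["transformer", "transformer_2"]

    def __init__(self):
        self.transformer = object()  # pretend denoiser; no transformer_2 attribute


pipe = ToyWanPipeline()
for component in pipe._lora_loadable_modules:
    model = getattr(pipe, component, None)  # plain getattr(pipe, "transformer_2") would raise
    if model is None:
        continue  # skip components this pipeline doesn't have
    print(f"would configure adapters on '{component}'")
```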
src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 84 additions & 35 deletions

```diff
@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
     )
 
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
     for key in list(original_state_dict.keys()):
         if key.endswith((".diff", ".diff_b")) and "norm" in key:
             # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
@@ -1852,15 +1863,26 @@
     for i in range(min_block, max_block + 1):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
 
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
 
             original_key = f"blocks.{i}.self_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1869,15 +1891,24 @@
 
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
 
             original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1886,15 +1917,24 @@
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+                has_alpha = alpha_key in original_state_dict
+                original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+                original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+                if original_key_A in original_state_dict:
+                    down_weight = original_state_dict.pop(original_key_A)
+                    converted_state_dict[converted_key_A] = down_weight
+                if original_key_B in original_state_dict:
+                    up_weight = original_state_dict.pop(original_key_B)
+                    converted_state_dict[converted_key_B] = up_weight
+                if has_alpha:
+                    scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                    converted_state_dict[converted_key_A] *= scale_down
+                    converted_state_dict[converted_key_B] *= scale_up
 
                 original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
                 converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1903,15 +1943,24 @@
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
 
             original_key = f"blocks.{i}.{o}.diff_b"
             converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
```

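A self-contained numerical sketch of what the new `get_alpha_scales` helper does (the tensor shape and alpha value below are made up): the `alpha / rank` factor that the original checkpoint applies at inference time is baked into the converted weights, split so that `scale_down * scale_up == alpha / rank`, with the two factors kept close in magnitude.

```python
import torch


def get_alpha_scales(down_weight, alpha):
    # Same arithmetic as the helper added above, minus the state-dict plumbing.
    rank = down_weight.shape[0]
    scale = alpha / rank  # LoRA forward pass scales the update by alpha / rank
    scale_down = scale
    scale_up = 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up


down_weight = torch.randn(64, 5120)  # rank 64, made-up input dim
scale_down, scale_up = get_alpha_scales(down_weight, alpha=16.0)
print(scale_down, scale_up, scale_down * scale_up)  # 0.5 0.5 0.25 == 16 / 64
```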
src/diffusers/loaders/lora_pipeline.py

Lines changed: 59 additions & 19 deletions

```diff
@@ -5065,7 +5065,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
     """
 
-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "transformer_2"]
     transformer_name = TRANSFORMER_NAME
 
     @classmethod
@@ -5270,15 +5270,35 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2"
+                    "Note that Wan2.1 models do not have a transformer_2 component."
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
@@ -5668,15 +5688,35 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2"
+                    "Note that Wan2.1 models do not have a transformer_2 component."
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
```

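Because `load_into_transformer_2=True` now raises an `AttributeError` on pipelines without a second denoiser, a defensive call pattern might look like the following sketch (`pipe`, `lora_id`, and `lora_file` are placeholders for objects assumed to exist already):

```python
# `pipe` is an already-constructed Wan pipeline; `lora_id` / `lora_file` are placeholders.
if hasattr(pipe, "transformer_2"):
    # Wan 2.2: the second denoiser is present, so the new kwarg is valid.
    pipe.load_lora_weights(lora_id, weight_name=lora_file, load_into_transformer_2=True)
else:
    # Wan 2.1 / WanVACE: no transformer_2, so stick to the default path
    # (passing load_into_transformer_2=True here would raise the AttributeError above).
    pipe.load_lora_weights(lora_id, weight_name=lora_file)
```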