Skip to content

Commit 0847255

Browse files
committed
load into 2nd transformer
1 parent 96864fb commit 0847255

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5064,7 +5064,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
50645064
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
50655065
"""
50665066

5067-
_lora_loadable_modules = ["transformer"]
5067+
_lora_loadable_modules = ["transformer", "transformer_2"]
50685068
transformer_name = TRANSFORMER_NAME
50695069

50705070
@classmethod
@@ -5269,15 +5269,28 @@ def load_lora_weights(
52695269
if not is_correct_format:
52705270
raise ValueError("Invalid LoRA checkpoint.")
52715271

5272-
self.load_lora_into_transformer(
5273-
state_dict,
5274-
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
5275-
adapter_name=adapter_name,
5276-
metadata=metadata,
5277-
_pipeline=self,
5278-
low_cpu_mem_usage=low_cpu_mem_usage,
5279-
hotswap=hotswap,
5280-
)
5272+
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
5273+
if load_into_transformer_2:
5274+
self.load_lora_into_transformer(
5275+
state_dict,
5276+
transformer=self.transformer_2,
5277+
adapter_name=adapter_name,
5278+
metadata=metadata,
5279+
_pipeline=self,
5280+
low_cpu_mem_usage=low_cpu_mem_usage,
5281+
hotswap=hotswap,
5282+
)
5283+
else:
5284+
self.load_lora_into_transformer(
5285+
state_dict,
5286+
transformer=getattr(self, self.transformer_name) if not hasattr(self,
5287+
"transformer") else self.transformer,
5288+
adapter_name=adapter_name,
5289+
metadata=metadata,
5290+
_pipeline=self,
5291+
low_cpu_mem_usage=low_cpu_mem_usage,
5292+
hotswap=hotswap,
5293+
)
52815294

52825295
@classmethod
52835296
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel

0 commit comments

Comments
 (0)