@@ -5064,7 +5064,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
     """
 
-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "transformer_2"]
     transformer_name = TRANSFORMER_NAME
 
     @classmethod
@@ -5269,15 +5269,28 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name) if not hasattr(self,
+                "transformer") else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
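
A minimal usage sketch of the new `load_into_transformer_2` keyword, assuming the loaded Wan pipeline actually exposes a `transformer_2` component; the model id and LoRA path below are placeholders, not taken from this diff:

    # Usage sketch only; "<wan-model-id>" and "<lora-repo-or-path>" are placeholder names.
    import torch
    from diffusers import WanPipeline

    pipe = WanPipeline.from_pretrained("<wan-model-id>", torch_dtype=torch.bfloat16)

    # Default behaviour: the LoRA is loaded into `pipe.transformer`.
    pipe.load_lora_weights("<lora-repo-or-path>", adapter_name="high_noise")

    # New behaviour from this change: route a LoRA into `pipe.transformer_2` instead.
    pipe.load_lora_weights(
        "<lora-repo-or-path>", adapter_name="low_noise", load_into_transformer_2=True
    )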