@@ -5064,7 +5064,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
50645064    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`]. 
50655065    """ 
50665066
5067-     _lora_loadable_modules  =  ["transformer" ]
5067+     _lora_loadable_modules  =  ["transformer" ,  "transformer_2" ]
50685068    transformer_name  =  TRANSFORMER_NAME 
50695069
50705070    @classmethod  
@@ -5269,15 +5269,28 @@ def load_lora_weights(
52695269        if  not  is_correct_format :
52705270            raise  ValueError ("Invalid LoRA checkpoint." )
52715271
5272-         self .load_lora_into_transformer (
5273-             state_dict ,
5274-             transformer = getattr (self , self .transformer_name ) if  not  hasattr (self , "transformer" ) else  self .transformer ,
5275-             adapter_name = adapter_name ,
5276-             metadata = metadata ,
5277-             _pipeline = self ,
5278-             low_cpu_mem_usage = low_cpu_mem_usage ,
5279-             hotswap = hotswap ,
5280-         )
5272+         load_into_transformer_2  =  kwargs .pop ("load_into_transformer_2" , False )
5273+         if  load_into_transformer_2 :
5274+             self .load_lora_into_transformer (
5275+                 state_dict ,
5276+                 transformer = self .transformer_2 ,
5277+                 adapter_name = adapter_name ,
5278+                 metadata = metadata ,
5279+                 _pipeline = self ,
5280+                 low_cpu_mem_usage = low_cpu_mem_usage ,
5281+                 hotswap = hotswap ,
5282+             )
5283+         else :
5284+             self .load_lora_into_transformer (
5285+                 state_dict ,
5286+                 transformer = getattr (self , self .transformer_name ) if  not  hasattr (self ,
5287+                                                                                 "transformer" ) else  self .transformer ,
5288+                 adapter_name = adapter_name ,
5289+                 metadata = metadata ,
5290+                 _pipeline = self ,
5291+                 low_cpu_mem_usage = low_cpu_mem_usage ,
5292+                 hotswap = hotswap ,
5293+             )
52815294
52825295    @classmethod  
52835296    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel 
0 commit comments