@@ -5687,15 +5687,34 @@ def load_lora_weights(
56875687        if  not  is_correct_format :
56885688            raise  ValueError ("Invalid LoRA checkpoint." )
56895689
5690-         self .load_lora_into_transformer (
5691-             state_dict ,
5692-             transformer = getattr (self , self .transformer_name ) if  not  hasattr (self , "transformer" ) else  self .transformer ,
5693-             adapter_name = adapter_name ,
5694-             metadata = metadata ,
5695-             _pipeline = self ,
5696-             low_cpu_mem_usage = low_cpu_mem_usage ,
5697-             hotswap = hotswap ,
5698-         )
5690+         load_into_transformer_2  =  kwargs .pop ("load_into_transformer_2" , False )
5691+         if  load_into_transformer_2 :
5692+             if  not  hasattr (self , "transformer_2" ):
5693+                 raise  ValueError (
5694+                     "Cannot load LoRA into transformer_2: transformer_2 is not available for this model" 
5695+                     "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True." 
5696+                 )
5697+             self .load_lora_into_transformer (
5698+                 state_dict ,
5699+                 transformer = self .transformer_2 ,
5700+                 adapter_name = adapter_name ,
5701+                 metadata = metadata ,
5702+                 _pipeline = self ,
5703+                 low_cpu_mem_usage = low_cpu_mem_usage ,
5704+                 hotswap = hotswap ,
5705+             )
5706+         else :
5707+             self .load_lora_into_transformer (
5708+                 state_dict ,
5709+                 transformer = getattr (self , self .transformer_name )
5710+                 if  not  hasattr (self , "transformer" )
5711+                 else  self .transformer ,
5712+                 adapter_name = adapter_name ,
5713+                 metadata = metadata ,
5714+                 _pipeline = self ,
5715+                 low_cpu_mem_usage = low_cpu_mem_usage ,
5716+                 hotswap = hotswap ,
5717+             )
56995718
57005719    @classmethod  
57015720    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel 
0 commit comments