@@ -558,6 +558,72 @@ def test_load_regular_lora(self):
558558        self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] ==  in_features  *  2 )
559559        self .assertFalse (np .allclose (original_output , lora_output , atol = 1e-3 , rtol = 1e-3 ))
560560
561+     def  test_lora_unload_with_parameter_expanded_shapes (self ):
562+         components , _ , _  =  self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
563+ 
564+         logger  =  logging .get_logger ("diffusers.loaders.lora_pipeline" )
565+         logger .setLevel (logging .DEBUG )
566+ 
567+         # Change the transformer config to mimic a real use case. 
568+         num_channels_without_control  =  4 
569+         transformer  =  FluxTransformer2DModel .from_config (
570+             components ["transformer" ].config , in_channels = num_channels_without_control 
571+         ).to (torch_device )
572+         self .assertTrue (
573+             transformer .config .in_channels  ==  num_channels_without_control ,
574+             f"Expected { num_channels_without_control } { transformer .config .in_channels = }  ,
575+         )
576+ 
577+         # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. 
578+         components ["transformer" ] =  transformer 
579+         pipe  =  FluxPipeline (** components )
580+         pipe  =  pipe .to (torch_device )
581+         pipe .set_progress_bar_config (disable = None )
582+ 
583+         _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
584+         control_image  =  inputs .pop ("control_image" )
585+         original_out  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
586+ 
587+         control_pipe  =  self .pipeline_class (** components )
588+         out_features , in_features  =  control_pipe .transformer .x_embedder .weight .shape 
589+         rank  =  4 
590+ 
591+         dummy_lora_A  =  torch .nn .Linear (2  *  in_features , rank , bias = False )
592+         dummy_lora_B  =  torch .nn .Linear (rank , out_features , bias = False )
593+         lora_state_dict  =  {
594+             "transformer.x_embedder.lora_A.weight" : dummy_lora_A .weight ,
595+             "transformer.x_embedder.lora_B.weight" : dummy_lora_B .weight ,
596+         }
597+         with  CaptureLogger (logger ) as  cap_logger :
598+             control_pipe .load_lora_weights (lora_state_dict , "adapter-1" )
599+             self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
600+ 
601+         inputs ["control_image" ] =  control_image 
602+         lora_out  =  control_pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
603+ 
604+         self .assertFalse (np .allclose (original_out , lora_out , rtol = 1e-4 , atol = 1e-4 ))
605+         self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] ==  2  *  in_features )
606+         self .assertTrue (pipe .transformer .config .in_channels  ==  2  *  in_features )
607+         self .assertTrue (cap_logger .out .startswith ("Expanding the nn.Linear input/output features for module" ))
608+ 
609+         control_pipe .unload_lora_weights ()
610+         self .assertTrue (
611+             control_pipe .transformer .config .in_channels  ==  num_channels_without_control ,
612+             f"Expected { num_channels_without_control } { control_pipe .transformer .config .in_channels = }  ,
613+         )
614+         loaded_pipe  =  FluxPipeline .from_pipe (control_pipe )
615+         self .assertTrue (
616+             loaded_pipe .transformer .config .in_channels  ==  num_channels_without_control ,
617+             f"Expected { num_channels_without_control } { loaded_pipe .transformer .config .in_channels = }  ,
618+         )
619+         inputs .pop ("control_image" )
620+         unloaded_lora_out  =  loaded_pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
621+ 
622+         self .assertFalse (np .allclose (unloaded_lora_out , lora_out , rtol = 1e-4 , atol = 1e-4 ))
623+         self .assertTrue (np .allclose (unloaded_lora_out , original_out , atol = 1e-4 , rtol = 1e-4 ))
624+         self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] ==  in_features )
625+         self .assertTrue (pipe .transformer .config .in_channels  ==  in_features )
626+ 
561627    @unittest .skip ("Not supported in Flux." ) 
562628    def  test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options (self ):
563629        pass 
0 commit comments