@@ -606,7 +606,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
-        control_pipe.unload_lora_weights()
+        control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
         self.assertTrue(
             control_pipe.transformer.config.in_channels == num_channels_without_control,
             f"Expected {num_channels_without_control} {control_pipe.transformer.config.in_channels=}",
@@ -624,6 +624,65 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
         self.assertTrue(pipe.transformer.config.in_channels == in_features)
 
+    def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} {transformer.config.in_channels=}",
+        )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+        components["transformer"] = transformer
+        pipe = FluxPipeline(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        control_image = inputs.pop("control_image")
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        control_pipe = self.pipeline_class(**components)
+        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        inputs["control_image"] = control_image
+        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
+        self.assertTrue(
+            control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
+            f"Expected {2 * num_channels_without_control} {control_pipe.transformer.config.in_channels=}",
+        )
+        no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
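
For context, a minimal sketch of how the `unload_lora_weights(reset_to_overwritten_params=...)` flag exercised by these tests might be used with a real Flux Control checkpoint; the repo ids are illustrative assumptions rather than part of this diff:

    import torch
    from diffusers import FluxControlPipeline

    pipe = FluxControlPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )
    # Loading a Flux Control LoRA expands x_embedder's input features in place
    # to accept the extra control-image channels (see the "Expanding the
    # nn.Linear input/output features" log asserted above).
    pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
    # True restores the pre-expansion parameter shapes; False (the behavior
    # covered by the no-reset test above) keeps the expanded module.
    pipe.unload_lora_weights(reset_to_overwritten_params=True)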