@@ -558,6 +558,72 @@ def test_load_regular_lora(self):
558
558
self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] == in_features * 2 )
559
559
self .assertFalse (np .allclose (original_output , lora_output , atol = 1e-3 , rtol = 1e-3 ))
560
560
561
+ def test_lora_unload_with_parameter_expanded_shapes (self ):
562
+ components , _ , _ = self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
563
+
564
+ logger = logging .get_logger ("diffusers.loaders.lora_pipeline" )
565
+ logger .setLevel (logging .DEBUG )
566
+
567
+ # Change the transformer config to mimic a real use case.
568
+ num_channels_without_control = 4
569
+ transformer = FluxTransformer2DModel .from_config (
570
+ components ["transformer" ].config , in_channels = num_channels_without_control
571
+ ).to (torch_device )
572
+ self .assertTrue (
573
+ transformer .config .in_channels == num_channels_without_control ,
574
+ f"Expected { num_channels_without_control } channels in the modified transformer but has { transformer .config .in_channels = } " ,
575
+ )
576
+
577
+ # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
578
+ components ["transformer" ] = transformer
579
+ pipe = FluxPipeline (** components )
580
+ pipe = pipe .to (torch_device )
581
+ pipe .set_progress_bar_config (disable = None )
582
+
583
+ _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
584
+ control_image = inputs .pop ("control_image" )
585
+ original_out = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
586
+
587
+ control_pipe = self .pipeline_class (** components )
588
+ out_features , in_features = control_pipe .transformer .x_embedder .weight .shape
589
+ rank = 4
590
+
591
+ dummy_lora_A = torch .nn .Linear (2 * in_features , rank , bias = False )
592
+ dummy_lora_B = torch .nn .Linear (rank , out_features , bias = False )
593
+ lora_state_dict = {
594
+ "transformer.x_embedder.lora_A.weight" : dummy_lora_A .weight ,
595
+ "transformer.x_embedder.lora_B.weight" : dummy_lora_B .weight ,
596
+ }
597
+ with CaptureLogger (logger ) as cap_logger :
598
+ control_pipe .load_lora_weights (lora_state_dict , "adapter-1" )
599
+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
600
+
601
+ inputs ["control_image" ] = control_image
602
+ lora_out = control_pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
603
+
604
+ self .assertFalse (np .allclose (original_out , lora_out , rtol = 1e-4 , atol = 1e-4 ))
605
+ self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] == 2 * in_features )
606
+ self .assertTrue (pipe .transformer .config .in_channels == 2 * in_features )
607
+ self .assertTrue (cap_logger .out .startswith ("Expanding the nn.Linear input/output features for module" ))
608
+
609
+ control_pipe .unload_lora_weights ()
610
+ self .assertTrue (
611
+ control_pipe .transformer .config .in_channels == num_channels_without_control ,
612
+ f"Expected { num_channels_without_control } channels in the modified transformer but has { control_pipe .transformer .config .in_channels = } " ,
613
+ )
614
+ loaded_pipe = FluxPipeline .from_pipe (control_pipe )
615
+ self .assertTrue (
616
+ loaded_pipe .transformer .config .in_channels == num_channels_without_control ,
617
+ f"Expected { num_channels_without_control } channels in the modified transformer but has { loaded_pipe .transformer .config .in_channels = } " ,
618
+ )
619
+ inputs .pop ("control_image" )
620
+ unloaded_lora_out = loaded_pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
621
+
622
+ self .assertFalse (np .allclose (unloaded_lora_out , lora_out , rtol = 1e-4 , atol = 1e-4 ))
623
+ self .assertTrue (np .allclose (unloaded_lora_out , original_out , atol = 1e-4 , rtol = 1e-4 ))
624
+ self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] == in_features )
625
+ self .assertTrue (pipe .transformer .config .in_channels == in_features )
626
+
561
627
@unittest .skip ("Not supported in Flux." )
562
628
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options (self ):
563
629
pass
0 commit comments