@@ -430,6 +430,68 @@ def test_correct_lora_configs_with_different_ranks(self):
         self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
         self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
 
+    def test_lora_unload_with_parameter_expanded_shapes(self):
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+        components["transformer"] = transformer
+        pipe = FluxPipeline(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        control_image = inputs.pop("control_image")
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
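+        # Build the control pipeline and a dummy LoRA whose x_embedder A matrix takes twice as many
+        # input features; loading it should force the input projection layer to be expanded.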
+        control_pipe = self.pipeline_class(**components)
+        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+        }
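+        # Loading the LoRA should log that the x_embedder features are being expanded.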
+        with CaptureLogger(logger) as cap_logger:
+            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
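+        # With the LoRA loaded, the denoiser accepts the control image, its output should differ
+        # from the original pipeline's, and x_embedder should now expect 2 * in_features inputs.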
+        inputs["control_image"] = control_image
+        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
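+        # Unloading the LoRA should shrink x_embedder back to the original number of input channels,
+        # and the resulting pipeline should reproduce the original (pre-LoRA) output.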
+        control_pipe.unload_lora_weights()
+        loaded_pipe = FluxPipeline.from_pipe(control_pipe)
+        self.assertTrue(
+            loaded_pipe.transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
+        )
+        inputs.pop("control_image")
+        unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features)
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass