@@ -430,10 +430,10 @@ def test_correct_lora_configs_with_different_ranks(self):
         self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
         self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
 
-    def test_lora_expanding_shape_with_normal_lora_raises_error(self):
-        # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
-        # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
-        # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
+    def test_lora_expanding_shape_with_normal_lora(self):
+        # This test checks that loading a regular lora works after a lora with expanded shapes
+        # (like control loras) has already been loaded. The opposite direction isn't supported
+        # and is tested further below.
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
 
         # Change the transformer config to mimic a real use case.
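For orientation, here is a minimal sketch of what "expanded shapes" means in these tests. The toy shapes mirror the tensors the tests construct below; this is an illustration, not diffusers internals:

import torch

# A control-style lora's A matrix takes twice the input features of the base
# x_embedder, while a regular lora matches the base width exactly (toy shapes).
rank, in_features, out_features = 4, 16, 32
regular_A = torch.nn.Linear(in_features, rank, bias=False)       # matches x_embedder
expander_A = torch.nn.Linear(2 * in_features, rank, bias=False)  # control lora: 2x inputs
assert regular_A.weight.shape == (rank, in_features)
assert expander_A.weight.shape == (rank, 2 * in_features)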
@@ -478,21 +478,16 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
478478 "transformer.x_embedder.lora_B.weight" : normal_lora_B .weight ,
479479 }
480480
481- # The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
482- # input features before expansion. This should raise an error about the weight shapes being incompatible.
483- self .assertRaisesRegex (
484- RuntimeError ,
485- "size mismatch for x_embedder.lora_A.adapter-2.weight" ,
486- pipe .load_lora_weights ,
487- lora_state_dict ,
488- "adapter-2" ,
489- )
490- # We should have `adapter-1` as the only adapter.
491- self .assertTrue (pipe .get_active_adapters () == ["adapter-1" ])
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-2")
+
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+        self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
+
+        lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-        # Check if the output is the same after lora loading error
-        lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3))
+        self.assertTrue("Found some LoRA modules for which the weights were zero-padded" in cap_logger.out)
+        self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
 
         # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
         # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
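The behavior this hunk now asserts, zero-padding the regular lora up to the expanded width, can be sketched as follows. Toy shapes; this illustrates the idea rather than diffusers' exact implementation:

import torch

# Pad a regular lora_A with zeros along the input dimension so it applies to an
# x_embedder whose in_features were doubled by a previously loaded control lora.
rank, in_features = 4, 16
lora_A = torch.randn(rank, in_features)          # regular lora, pre-expansion width
padded_A = torch.zeros(rank, 2 * in_features)    # x_embedder now expects 2x inputs
padded_A[:, :in_features] = lora_A               # copy weights; extra columns stay zero
x = torch.randn(1, 2 * in_features)
# The padded adapter ignores the extra (control) input channels entirely:
assert torch.allclose(x @ padded_A.T, x[:, :in_features] @ lora_A.T)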
@@ -524,8 +519,8 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
 
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
-            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
 
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
         self.assertTrue(pipe.transformer.config.in_channels == in_features)
         self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
@@ -535,17 +530,107 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
535530 "transformer.x_embedder.lora_B.weight" : shape_expander_lora_B .weight ,
536531 }
537532
538- # We should check for input shapes being incompatible here. But because above mentioned issue is
539- # not a supported use case, and because of the PEFT renaming, we will currently have a shape
540- # mismatch error.
533+ # We should check for input shapes being incompatible here.
541534 self .assertRaisesRegex (
542535 RuntimeError ,
543- "size mismatch for x_embedder.lora_A.adapter-2 .weight" ,
536+ "x_embedder.lora_A.weight" ,
544537 pipe .load_lora_weights ,
545538 lora_state_dict ,
546539 "adapter-2" ,
547540 )
548541
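The size mismatch asserted above can be reproduced in isolation: assigning an expanded (2x in_features) lora_A state into a module sized for the original width fails at state-dict load time. Toy shapes and hypothetical tensors:

import torch

# An expanded lora_A cannot be copied into a module sized for the original width.
rank, in_features = 4, 16
module_A = torch.nn.Linear(in_features, rank, bias=False)        # un-expanded target
expanded_state = {"weight": torch.randn(rank, 2 * in_features)}  # what the lora ships
try:
    module_A.load_state_dict(expanded_state)
except RuntimeError as err:
    print(err)  # size mismatch for weight: ... torch.Size([4, 32]) vs torch.Size([4, 16])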
+    def test_fuse_expanded_lora_with_regular_lora(self):
+        # This test checks that a regular lora can be loaded and fused on top of a lora with
+        # expanded shapes (like control loras): with both adapters active, the fused output
+        # should match the unfused multi-adapter output.
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        components["transformer"] = transformer
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
+        }
+        pipe.load_lora_weights(lora_state_dict, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
+        normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+        }
+
+        pipe.load_lora_weights(lora_state_dict, "adapter-2")
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
+        lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
+        self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
+        self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
+
+        pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
+        lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
+
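The fuse check at the end of the test above relies on the standard LoRA identity: fusing folds each adapter's delta B @ A into the base weight, so the fused forward pass matches the unfused multi-adapter one. A self-contained numeric sketch with toy shapes and unit scales (matching the set_adapters call):

import torch

torch.manual_seed(0)
out_f, in_f, rank = 8, 6, 2
W = torch.randn(out_f, in_f)                                   # base weight
A1, B1 = torch.randn(rank, in_f), torch.randn(out_f, rank)     # adapter-1 (post-padding)
A2, B2 = torch.randn(rank, in_f), torch.randn(out_f, rank)     # adapter-2
W_fused = W + B1 @ A1 + B2 @ A2                                # deltas of both adapters sum
x = torch.randn(1, in_f)
unfused = x @ W.T + (x @ A1.T) @ B1.T + (x @ A2.T) @ B2.T      # base + two adapter branches
assert torch.allclose(x @ W_fused.T, unfused, atol=1e-6)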
+    def test_load_regular_lora(self):
+        # This test checks if a regular lora (e.g., one trained on Flux.1-Dev) can be loaded into
+        # a transformer with more input channels than Flux.1-Dev. Examples of such transformers
+        # include Flux Fill, Flux Control, etc.
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
+        rank = 4
+        in_features = in_features // 2  # to mimic the Flux.1-Dev LoRA.
+        normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
+        normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+        }
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.INFO)
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertTrue("Found some LoRA modules for which the weights were zero-padded" in cap_logger.out)
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
+        self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass