@@ -430,10 +430,10 @@ def test_correct_lora_configs_with_different_ranks(self):
430430        self .assertTrue (not  np .allclose (original_output , lora_output_diff_alpha , atol = 1e-3 , rtol = 1e-3 ))
431431        self .assertTrue (not  np .allclose (lora_output_diff_alpha , lora_output_same_rank , atol = 1e-3 , rtol = 1e-3 ))
432432
433-     def  test_lora_expanding_shape_with_normal_lora_raises_error (self ):
434-         # TODO:  This test checks if an error is raised  when a lora expands  shapes (like control loras) but 
435-         # another lora with correct shapes is loaded. This is not  supported at the moment  and should raise an error.  
436-         # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180  
433+     def  test_lora_expanding_shape_with_normal_lora (self ):
434+         # This test checks if it works  when a lora with expanded  shapes (like control loras) but 
435+         # another lora with correct shapes is loaded. The opposite direction isn't  supported and is  
436+         # tested with it.  
437437        components , _ , _  =  self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
438438
439439        # Change the transformer config to mimic a real use case. 
@@ -478,21 +478,16 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
478478            "transformer.x_embedder.lora_B.weight" : normal_lora_B .weight ,
479479        }
480480
481-         # The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct 
482-         # input features before expansion. This should raise an error about the weight shapes being incompatible. 
483-         self .assertRaisesRegex (
484-             RuntimeError ,
485-             "size mismatch for x_embedder.lora_A.adapter-2.weight" ,
486-             pipe .load_lora_weights ,
487-             lora_state_dict ,
488-             "adapter-2" ,
489-         )
490-         # We should have `adapter-1` as the only adapter. 
491-         self .assertTrue (pipe .get_active_adapters () ==  ["adapter-1" ])
481+         with  CaptureLogger (logger ) as  cap_logger :
482+             pipe .load_lora_weights (lora_state_dict , "adapter-2" )
483+ 
484+         self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
485+         self .assertTrue (pipe .get_active_adapters () ==  ["adapter-2" ])
486+ 
487+         lora_output_2  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
492488
493-         # Check if the output is the same after lora loading error 
494-         lora_output_after_error  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
495-         self .assertTrue (np .allclose (lora_output , lora_output_after_error , atol = 1e-3 , rtol = 1e-3 ))
489+         self .assertTrue ("Found some LoRA modules for which the weights were zero-padded"  in  cap_logger .out )
490+         self .assertFalse (np .allclose (lora_output , lora_output_2 , atol = 1e-3 , rtol = 1e-3 ))
496491
497492        # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. 
498493        # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the 
@@ -524,8 +519,8 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
524519
525520        with  CaptureLogger (logger ) as  cap_logger :
526521            pipe .load_lora_weights (lora_state_dict , "adapter-1" )
527-             self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
528522
523+         self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
529524        self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] ==  in_features )
530525        self .assertTrue (pipe .transformer .config .in_channels  ==  in_features )
531526        self .assertFalse (cap_logger .out .startswith ("Expanding the nn.Linear input/output features for module" ))
@@ -535,17 +530,107 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
535530            "transformer.x_embedder.lora_B.weight" : shape_expander_lora_B .weight ,
536531        }
537532
538-         # We should check for input shapes being incompatible here. But because above mentioned issue is 
539-         # not a supported use case, and because of the PEFT renaming, we will currently have a shape 
540-         # mismatch error. 
533+         # We should check for input shapes being incompatible here. 
541534        self .assertRaisesRegex (
542535            RuntimeError ,
543-             "size mismatch for  x_embedder.lora_A.adapter-2 .weight" ,
536+             "x_embedder.lora_A.weight" ,
544537            pipe .load_lora_weights ,
545538            lora_state_dict ,
546539            "adapter-2" ,
547540        )
548541
542+     def  test_fuse_expanded_lora_with_regular_lora (self ):
543+         # This test checks if it works when a lora with expanded shapes (like control loras) but 
544+         # another lora with correct shapes is loaded. The opposite direction isn't supported and is 
545+         # tested with it. 
546+         components , _ , _  =  self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
547+ 
548+         # Change the transformer config to mimic a real use case. 
549+         num_channels_without_control  =  4 
550+         transformer  =  FluxTransformer2DModel .from_config (
551+             components ["transformer" ].config , in_channels = num_channels_without_control 
552+         ).to (torch_device )
553+         components ["transformer" ] =  transformer 
554+ 
555+         pipe  =  self .pipeline_class (** components )
556+         pipe  =  pipe .to (torch_device )
557+         pipe .set_progress_bar_config (disable = None )
558+ 
559+         logger  =  logging .get_logger ("diffusers.loaders.lora_pipeline" )
560+         logger .setLevel (logging .DEBUG )
561+ 
562+         out_features , in_features  =  pipe .transformer .x_embedder .weight .shape 
563+         rank  =  4 
564+ 
565+         shape_expander_lora_A  =  torch .nn .Linear (2  *  in_features , rank , bias = False )
566+         shape_expander_lora_B  =  torch .nn .Linear (rank , out_features , bias = False )
567+         lora_state_dict  =  {
568+             "transformer.x_embedder.lora_A.weight" : shape_expander_lora_A .weight ,
569+             "transformer.x_embedder.lora_B.weight" : shape_expander_lora_B .weight ,
570+         }
571+         pipe .load_lora_weights (lora_state_dict , "adapter-1" )
572+         self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
573+ 
574+         _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
575+         lora_output  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
576+ 
577+         normal_lora_A  =  torch .nn .Linear (in_features , rank , bias = False )
578+         normal_lora_B  =  torch .nn .Linear (rank , out_features , bias = False )
579+         lora_state_dict  =  {
580+             "transformer.x_embedder.lora_A.weight" : normal_lora_A .weight ,
581+             "transformer.x_embedder.lora_B.weight" : normal_lora_B .weight ,
582+         }
583+ 
584+         pipe .load_lora_weights (lora_state_dict , "adapter-2" )
585+         self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
586+ 
587+         lora_output_2  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
588+ 
589+         pipe .set_adapters (["adapter-1" , "adapter-2" ], [1.0 , 1.0 ])
590+         lora_output_3  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
591+ 
592+         self .assertFalse (np .allclose (lora_output , lora_output_2 , atol = 1e-3 , rtol = 1e-3 ))
593+         self .assertFalse (np .allclose (lora_output , lora_output_3 , atol = 1e-3 , rtol = 1e-3 ))
594+         self .assertFalse (np .allclose (lora_output_2 , lora_output_3 , atol = 1e-3 , rtol = 1e-3 ))
595+ 
596+         pipe .fuse_lora (lora_scale = 1.0 , adapter_names = ["adapter-1" , "adapter-2" ])
597+         lora_output_4  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
598+         self .assertTrue (np .allclose (lora_output_3 , lora_output_4 , atol = 1e-3 , rtol = 1e-3 ))
599+ 
600+     def  test_load_regular_lora (self ):
601+         # This test checks if a regular lora (think of one trained Flux.1 Dev for example) can be loaded 
602+         # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those 
603+         # transformers include Flux Fill, Flux Control, etc. 
604+         components , _ , _  =  self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
605+         pipe  =  self .pipeline_class (** components )
606+         pipe  =  pipe .to (torch_device )
607+         pipe .set_progress_bar_config (disable = None )
608+         _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
609+ 
610+         original_output  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
611+ 
612+         out_features , in_features  =  pipe .transformer .x_embedder .weight .shape 
613+         rank  =  4 
614+         in_features  =  in_features  //  2   # to mimic the Flux.1-Dev LoRA. 
615+         normal_lora_A  =  torch .nn .Linear (in_features , rank , bias = False )
616+         normal_lora_B  =  torch .nn .Linear (rank , out_features , bias = False )
617+         lora_state_dict  =  {
618+             "transformer.x_embedder.lora_A.weight" : normal_lora_A .weight ,
619+             "transformer.x_embedder.lora_B.weight" : normal_lora_B .weight ,
620+         }
621+ 
622+         logger  =  logging .get_logger ("diffusers.loaders.lora_pipeline" )
623+         logger .setLevel (logging .INFO )
624+         with  CaptureLogger (logger ) as  cap_logger :
625+             pipe .load_lora_weights (lora_state_dict , "adapter-1" )
626+         self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
627+ 
628+         lora_output  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
629+ 
630+         self .assertTrue ("Found some LoRA modules for which the weights were zero-padded"  in  cap_logger .out )
631+         self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] ==  in_features  *  2 )
632+         self .assertFalse (np .allclose (original_output , lora_output , atol = 1e-3 , rtol = 1e-3 ))
633+ 
549634    @unittest .skip ("Not supported in Flux." ) 
550635    def  test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options (self ):
551636        pass 
0 commit comments