@@ -330,7 +330,8 @@ def test_lora_parameter_expanded_shapes(self):
         }
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
-            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
 
         lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -339,6 +340,7 @@ def test_lora_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
+        # Test the opposite direction, where the LoRA params are zero-padded.
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
@@ -349,15 +351,21 @@ def test_lora_parameter_expanded_shapes(self):
             "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
             "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
         }
-        # We should error out because lora input features is less than original. We only
-        # support expanding the module, not shrinking it
-        with self.assertRaises(NotImplementedError):
+        with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
 
-    def test_lora_expanding_shape_with_normal_lora_raises_error(self):
-        # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
-        # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
-        # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
+
+    def test_normal_lora_with_expanded_lora_raises_error(self):
+        # Test the following situation: load a regular LoRA (such as one trained on Flux.1-Dev),
+        # then load a shape-expanded LoRA (such as a Control LoRA).
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
 
         # Change the transformer config to mimic a real use case.
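
For context on what the updated test asserts: instead of raising `NotImplementedError` when a LoRA's input features are smaller than the module's already-expanded input features, the loader now zero-pads the LoRA weights to match and logs a message. A minimal sketch of the idea, using a hypothetical `maybe_zero_pad_lora_A` helper (illustrative only, not the actual diffusers code path):

```python
import torch


def maybe_zero_pad_lora_A(lora_A: torch.Tensor, module_in_features: int) -> torch.Tensor:
    """Zero-pad a LoRA A matrix of shape (rank, in_features) along its input
    dimension so it matches a module whose in_features was previously expanded.
    Hypothetical helper for illustration, not the diffusers implementation."""
    rank, lora_in_features = lora_A.shape
    if lora_in_features == module_in_features:
        return lora_A
    if lora_in_features > module_in_features:
        # Shrinking a module to fit a larger LoRA remains unsupported.
        raise NotImplementedError("Only expansion via zero padding is supported.")
    padded = torch.zeros(rank, module_in_features, dtype=lora_A.dtype, device=lora_A.device)
    padded[:, :lora_in_features] = lora_A  # extra input channels contribute zero
    return padded


# Example: module expanded to 2 * in_features, LoRA trained on the original width.
in_features, rank = 4, 2
lora_A = torch.randn(rank, in_features)
padded_A = maybe_zero_pad_lora_A(lora_A, 2 * in_features)
assert padded_A.shape == (rank, 2 * in_features)
```

Because the padding is zeros, the padded input channels produce no contribution in the low-rank update, so the LoRA behaves as it did at its original width while remaining loadable into the expanded module.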