     numpy_cosine_similarity_distance,
     require_big_gpu_with_torch_cuda,
     require_peft_backend,
-    require_peft_version_greater,
     require_torch_gpu,
     slow,
     torch_device,
@@ -331,7 +330,8 @@ def test_lora_parameter_expanded_shapes(self):
         }
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
-        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
 
         lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
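The assertions in the next hunk verify the shape-expansion path of test_lora_parameter_expanded_shapes: when the incoming lora_A weight for transformer.x_embedder is twice as wide as the base layer, the loader expands the underlying nn.Linear (and config.in_channels) instead of rejecting the checkpoint. A minimal sketch of that idea on a bare torch.nn.Linear, with the new input columns zero-initialized here so original outputs are preserved (an illustrative assumption, not the actual diffusers code path):

    import torch

    def expand_linear_in_features(linear: torch.nn.Linear, new_in_features: int) -> torch.nn.Linear:
        # Build a wider layer and copy the old weight into the leading columns;
        # the extra columns start at zero, so outputs for the original inputs are unchanged.
        expanded = torch.nn.Linear(new_in_features, linear.out_features, bias=linear.bias is not None)
        with torch.no_grad():
            expanded.weight.zero_()
            expanded.weight[:, : linear.in_features].copy_(linear.weight)
            if linear.bias is not None:
                expanded.bias.copy_(linear.bias)
        return expanded

    x_embedder = torch.nn.Linear(4, 8)
    x_embedder = expand_linear_in_features(x_embedder, 2 * x_embedder.in_features)
    assert x_embedder.weight.data.shape[1] == 8  # mirrors the shape check in the test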
@@ -340,85 +340,32 @@ def test_lora_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
-    @require_peft_version_greater("0.13.2")
-    def test_lora_B_bias(self):
-        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        # keep track of the bias values of the base layers to perform checks later.
-        bias_values = {}
-        for name, module in pipe.transformer.named_modules():
-            if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
-                if module.bias is not None:
-                    bias_values[name] = module.bias.data.clone()
-
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
-        logger.setLevel(logging.INFO)
-
-        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        denoiser_lora_config.lora_bias = False
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        pipe.delete_adapters("adapter-1")
-
-        denoiser_lora_config.lora_bias = True
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
-        self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
-        self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
-
-    # for now this is flux control lora specific but can be generalized later and added to ./utils.py
-    def test_correct_lora_configs_with_different_ranks(self):
-        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        # Testing opposite direction where the LoRA params are zero-padded.
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        pipe.transformer.delete_adapters("adapter-1")
-
-        # change the rank_pattern
-        updated_rank = denoiser_lora_config.r * 2
-        denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank}
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        assert pipe.transformer.peft_config["adapter-1"].rank_pattern == {
-            "single_transformer_blocks.0.attn.to_k": updated_rank
+        dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
         }
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
 
-        lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
-        self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
-        pipe.transformer.delete_adapters("adapter-1")
-
-        # similarly change the alpha_pattern
-        updated_alpha = denoiser_lora_config.lora_alpha * 2
-        denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": updated_alpha}
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == {
-            "single_transformer_blocks.0.attn.to_k": updated_alpha
-        }
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
 
-        lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-        self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
-        self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
 
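For the zero-padding direction exercised just above, the loaded lora_A weight (built from torch.nn.Linear(1, rank)) is narrower than the already-expanded base layer, and the loader pads the missing input columns with zeros so the low-rank update leaves the extra channels untouched. A rough sketch of the padding step on plain tensors (illustrative only, not the diffusers implementation):

    import torch
    import torch.nn.functional as F

    rank, base_in_features = 4, 8
    lora_A = torch.randn(rank, 1)  # trained against a single input feature, as in the test

    # Right-pad the input dimension with zeros up to the base layer's width.
    padded_lora_A = F.pad(lora_A, (0, base_in_features - lora_A.shape[1]))
    assert padded_lora_A.shape == (rank, base_in_features)

    # The padded columns contribute nothing to the LoRA update for those inputs.
    x = torch.randn(2, base_in_features)
    assert torch.allclose(x @ padded_lora_A.T, x[:, :1] @ lora_A.T)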
-    def test_lora_expanding_shape_with_normal_lora(self):
-        # This test checks if it works when a lora with expanded shapes (like control loras) but
-        # another lora with correct shapes is loaded. The opposite direction isn't supported and is
-        # tested with it.
+    def test_normal_lora_with_expanded_lora_raises_error(self):
+        # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
+        # load shape expanded LoRA (such as Control LoRA).
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
 
         # Change the transformer config to mimic a real use case.
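The renamed test_normal_lora_with_expanded_lora_raises_error above (truncated by this hunk) covers the unsupported order: a regular LoRA is attached first, and loading a shape-expanded Control-style LoRA afterwards is expected to raise an error, per the test name. The underlying conflict is a plain shape mismatch, roughly (illustrative tensors only, not the diffusers error path):

    import torch

    in_features, rank = 4, 2
    x = torch.randn(1, in_features)

    regular_lora_A = torch.randn(rank, in_features)       # matches the base layer's input width
    expanded_lora_A = torch.randn(rank, 2 * in_features)  # Control-style LoRA expects a wider input

    x @ regular_lora_A.T  # fine: (1, 4) @ (4, 2)
    try:
        x @ expanded_lora_A.T  # (1, 4) @ (8, 2) -> shape mismatch
    except RuntimeError as err:
        print(f"cannot apply the expanded LoRA without widening the base layer first: {err}")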