36
36
numpy_cosine_similarity_distance ,
37
37
require_big_gpu_with_torch_cuda ,
38
38
require_peft_backend ,
39
- require_peft_version_greater ,
40
39
require_torch_gpu ,
41
40
slow ,
42
41
torch_device ,
@@ -331,7 +330,8 @@ def test_lora_parameter_expanded_shapes(self):
331
330
}
332
331
with CaptureLogger (logger ) as cap_logger :
333
332
pipe .load_lora_weights (lora_state_dict , "adapter-1" )
334
- self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
333
+
334
+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
335
335
336
336
lora_out = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
337
337
@@ -340,85 +340,32 @@ def test_lora_parameter_expanded_shapes(self):
340
340
self .assertTrue (pipe .transformer .config .in_channels == 2 * in_features )
341
341
self .assertTrue (cap_logger .out .startswith ("Expanding the nn.Linear input/output features for module" ))
342
342
343
- @require_peft_version_greater ("0.13.2" )
344
- def test_lora_B_bias (self ):
345
- components , _ , denoiser_lora_config = self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
346
- pipe = self .pipeline_class (** components )
347
- pipe = pipe .to (torch_device )
348
- pipe .set_progress_bar_config (disable = None )
349
-
350
- # keep track of the bias values of the base layers to perform checks later.
351
- bias_values = {}
352
- for name , module in pipe .transformer .named_modules ():
353
- if any (k in name for k in ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
354
- if module .bias is not None :
355
- bias_values [name ] = module .bias .data .clone ()
356
-
357
- _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
358
-
359
- logger = logging .get_logger ("diffusers.loaders.lora_pipeline" )
360
- logger .setLevel (logging .INFO )
361
-
362
- original_output = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
363
-
364
- denoiser_lora_config .lora_bias = False
365
- pipe .transformer .add_adapter (denoiser_lora_config , "adapter-1" )
366
- lora_bias_false_output = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
367
- pipe .delete_adapters ("adapter-1" )
368
-
369
- denoiser_lora_config .lora_bias = True
370
- pipe .transformer .add_adapter (denoiser_lora_config , "adapter-1" )
371
- lora_bias_true_output = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
372
-
373
- self .assertFalse (np .allclose (original_output , lora_bias_false_output , atol = 1e-3 , rtol = 1e-3 ))
374
- self .assertFalse (np .allclose (original_output , lora_bias_true_output , atol = 1e-3 , rtol = 1e-3 ))
375
- self .assertFalse (np .allclose (lora_bias_false_output , lora_bias_true_output , atol = 1e-3 , rtol = 1e-3 ))
376
-
377
- # for now this is flux control lora specific but can be generalized later and added to ./utils.py
378
- def test_correct_lora_configs_with_different_ranks (self ):
379
- components , _ , denoiser_lora_config = self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
343
+ # Testing opposite direction where the LoRA params are zero-padded.
344
+ components , _ , _ = self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
380
345
pipe = self .pipeline_class (** components )
381
346
pipe = pipe .to (torch_device )
382
347
pipe .set_progress_bar_config (disable = None )
383
- _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
384
-
385
- original_output = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
386
-
387
- pipe .transformer .add_adapter (denoiser_lora_config , "adapter-1" )
388
- lora_output_same_rank = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
389
- pipe .transformer .delete_adapters ("adapter-1" )
390
-
391
- # change the rank_pattern
392
- updated_rank = denoiser_lora_config .r * 2
393
- denoiser_lora_config .rank_pattern = {"single_transformer_blocks.0.attn.to_k" : updated_rank }
394
- pipe .transformer .add_adapter (denoiser_lora_config , "adapter-1" )
395
- assert pipe .transformer .peft_config ["adapter-1" ].rank_pattern == {
396
- "single_transformer_blocks.0.attn.to_k" : updated_rank
348
+ dummy_lora_A = torch .nn .Linear (1 , rank , bias = False )
349
+ dummy_lora_B = torch .nn .Linear (rank , out_features , bias = False )
350
+ lora_state_dict = {
351
+ "transformer.x_embedder.lora_A.weight" : dummy_lora_A .weight ,
352
+ "transformer.x_embedder.lora_B.weight" : dummy_lora_B .weight ,
397
353
}
354
+ with CaptureLogger (logger ) as cap_logger :
355
+ pipe .load_lora_weights (lora_state_dict , "adapter-1" )
398
356
399
- lora_output_diff_rank = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
400
-
401
- self .assertTrue (not np .allclose (original_output , lora_output_same_rank , atol = 1e-3 , rtol = 1e-3 ))
402
- self .assertTrue (not np .allclose (lora_output_diff_rank , lora_output_same_rank , atol = 1e-3 , rtol = 1e-3 ))
403
- pipe .transformer .delete_adapters ("adapter-1" )
404
-
405
- # similarly change the alpha_pattern
406
- updated_alpha = denoiser_lora_config .lora_alpha * 2
407
- denoiser_lora_config .alpha_pattern = {"single_transformer_blocks.0.attn.to_k" : updated_alpha }
408
- pipe .transformer .add_adapter (denoiser_lora_config , "adapter-1" )
409
- assert pipe .transformer .peft_config ["adapter-1" ].alpha_pattern == {
410
- "single_transformer_blocks.0.attn.to_k" : updated_alpha
411
- }
357
+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
412
358
413
- lora_output_diff_alpha = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
359
+ lora_out = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
414
360
415
- self .assertTrue (not np .allclose (original_output , lora_output_diff_alpha , atol = 1e-3 , rtol = 1e-3 ))
416
- self .assertTrue (not np .allclose (lora_output_diff_alpha , lora_output_same_rank , atol = 1e-3 , rtol = 1e-3 ))
361
+ self .assertFalse (np .allclose (original_out , lora_out , rtol = 1e-4 , atol = 1e-4 ))
362
+ self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] == 2 * in_features )
363
+ self .assertTrue (pipe .transformer .config .in_channels == 2 * in_features )
364
+ self .assertTrue ("The following LoRA modules were zero padded to match the state dict of" in cap_logger .out )
417
365
418
- def test_lora_expanding_shape_with_normal_lora (self ):
419
- # This test checks if it works when a lora with expanded shapes (like control loras) but
420
- # another lora with correct shapes is loaded. The opposite direction isn't supported and is
421
- # tested with it.
366
+ def test_normal_lora_with_expanded_lora_raises_error (self ):
367
+ # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
368
+ # load shape expanded LoRA (such as Control LoRA).
422
369
components , _ , _ = self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
423
370
424
371
# Change the transformer config to mimic a real use case.
0 commit comments