@@ -159,6 +159,7 @@ def test_with_alpha_in_state_dict(self):
159159        )
160160        self .assertFalse (np .allclose (images_lora_with_alpha , images_lora , atol = 1e-3 , rtol = 1e-3 ))
161161
162+     # flux control lora specific 
162163    def  test_with_norm_in_state_dict (self ):
163164        components , _ , denoiser_lora_config  =  self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
164165        pipe  =  self .pipeline_class (** components )
@@ -210,6 +211,7 @@ def test_with_norm_in_state_dict(self):
210211            cap_logger .out .startswith ("Unsupported keys found in state dict when trying to load normalization layers" )
211212        )
212213
214+     # flux control lora specific 
213215    def  test_lora_parameter_expanded_shapes (self ):
214216        components , _ , _  =  self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
215217        pipe  =  self .pipeline_class (** components )
@@ -254,6 +256,7 @@ def test_lora_parameter_expanded_shapes(self):
254256        with  self .assertRaises (NotImplementedError ):
255257            pipe .load_lora_weights (lora_state_dict , "adapter-1" )
256258
259+     # flux control lora specific 
257260    @require_peft_version_greater ("0.13.2" ) 
258261    def  test_lora_B_bias (self ):
259262        components , _ , denoiser_lora_config  =  self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
@@ -275,12 +278,53 @@ def test_lora_B_bias(self):
275278
276279        denoiser_lora_config .lora_bias  =  True 
277280        pipe .transformer .add_adapter (denoiser_lora_config , "adapter-1" )
278-         lora_bias_true_output  =  pipe (** inputs )[0 ]
281+         lora_bias_true_output  =  pipe (** inputs ,  generator = torch . manual_seed ( 0 ) )[0 ]
279282
280283        self .assertFalse (np .allclose (original_output , lora_bias_false_output , atol = 1e-3 , rtol = 1e-3 ))
281284        self .assertFalse (np .allclose (original_output , lora_bias_true_output , atol = 1e-3 , rtol = 1e-3 ))
282285        self .assertFalse (np .allclose (lora_bias_false_output , lora_bias_true_output , atol = 1e-3 , rtol = 1e-3 ))
283286
287+     # for now this is flux control lora specific but can be generalized later and added to ./utils.py 
288+     def  test_correct_lora_configs_with_different_ranks (self ):
289+         components , _ , denoiser_lora_config  =  self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
290+         pipe  =  self .pipeline_class (** components )
291+         pipe  =  pipe .to (torch_device )
292+         pipe .set_progress_bar_config (disable = None )
293+         _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
294+ 
295+         original_output  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
296+ 
297+         pipe .transformer .add_adapter (denoiser_lora_config , "adapter-1" )
298+         lora_output_same_rank  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
299+         pipe .transformer .delete_adapters ("adapter-1" )
300+ 
301+         # change the rank_pattern 
302+         updated_rank  =  denoiser_lora_config .r  *  2 
303+         denoiser_lora_config .rank_pattern  =  {"single_transformer_blocks.0.attn.to_k" : updated_rank }
304+         pipe .transformer .add_adapter (denoiser_lora_config , "adapter-1" )
305+         assert  pipe .transformer .peft_config ["adapter-1" ].rank_pattern  ==  {
306+             "single_transformer_blocks.0.attn.to_k" : updated_rank 
307+         }
308+ 
309+         lora_output_diff_rank  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
310+ 
311+         self .assertTrue (not  np .allclose (original_output , lora_output_same_rank , atol = 1e-3 , rtol = 1e-3 ))
312+         self .assertTrue (not  np .allclose (lora_output_diff_rank , lora_output_same_rank , atol = 1e-3 , rtol = 1e-3 ))
313+         pipe .transformer .delete_adapters ("adapter-1" )
314+ 
315+         # similarly change the alpha_pattern 
316+         updated_alpha  =  denoiser_lora_config .lora_alpha  *  2 
317+         denoiser_lora_config .alpha_pattern  =  {"single_transformer_blocks.0.attn.to_k" : updated_alpha }
318+         pipe .transformer .add_adapter (denoiser_lora_config , "adapter-1" )
319+         assert  pipe .transformer .peft_config ["adapter-1" ].alpha_pattern  ==  {
320+             "single_transformer_blocks.0.attn.to_k" : updated_alpha 
321+         }
322+ 
323+         lora_output_diff_alpha  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
324+ 
325+         self .assertTrue (not  np .allclose (original_output , lora_output_diff_alpha , atol = 1e-3 , rtol = 1e-3 ))
326+         self .assertTrue (not  np .allclose (lora_output_diff_alpha , lora_output_same_rank , atol = 1e-3 , rtol = 1e-3 ))
327+ 
284328    @unittest .skip ("Not supported in Flux." ) 
285329    def  test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options (self ):
286330        pass 
0 commit comments