 import numpy as np
 import safetensors.torch
 import torch
+from PIL import Image
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

-from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel
 from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -159,7 +160,80 @@ def test_with_alpha_in_state_dict(self):
         )
         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))

-    # flux control lora specific
+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+        pass
+
+    @unittest.skip("Not supported in Flux.")
+    def test_modify_padding_mode(self):
+        pass
+
+
+class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+    pipeline_class = FluxControlPipeline
+    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_kwargs = {}
+    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+    transformer_kwargs = {
+        "patch_size": 1,
+        "in_channels": 8,
+        "out_channels": 4,
+        "num_layers": 1,
+        "num_single_layers": 1,
+        "attention_head_dim": 16,
+        "num_attention_heads": 2,
+        "joint_attention_dim": 32,
+        "pooled_projection_dim": 32,
+        "axes_dims_rope": [4, 4, 8],
+    }
+    transformer_cls = FluxTransformer2DModel
+    vae_kwargs = {
+        "sample_size": 32,
+        "in_channels": 3,
+        "out_channels": 3,
+        "block_out_channels": (4,),
+        "layers_per_block": 1,
+        "latent_channels": 1,
+        "norm_num_groups": 1,
+        "use_quant_conv": False,
+        "use_post_quant_conv": False,
+        "shift_factor": 0.0609,
+        "scaling_factor": 1.5035,
+    }
+    has_two_text_encoders = True
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+    text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+    text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+    @property
+    def output_shape(self):
+        return (1, 8, 8, 3)
+
+    def get_dummy_inputs(self, with_generator=True):
+        batch_size = 1
+        sequence_length = 10
+        num_channels = 4
+        sizes = (32, 32)
+
+        generator = torch.manual_seed(0)
+        noise = floats_tensor((batch_size, num_channels) + sizes)
+        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+        pipeline_inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),
+            "num_inference_steps": 4,
+            "guidance_scale": 0.0,
+            "height": 8,
+            "width": 8,
+            "output_type": "np",
+        }
+        if with_generator:
+            pipeline_inputs.update({"generator": generator})
+
+        return noise, input_ids, pipeline_inputs
+
     def test_with_norm_in_state_dict(self):
         components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
@@ -184,7 +258,7 @@ def test_with_norm_in_state_dict(self):

                 with CaptureLogger(logger) as cap_logger:
                     pipe.load_lora_weights(norm_state_dict)
-                    lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

                 self.assertTrue(
                     cap_logger.out.startswith(
@@ -211,18 +285,38 @@ def test_with_norm_in_state_dict(self):
             cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers")
         )

-    # flux control lora specific
     def test_lora_parameter_expanded_shapes(self):
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

         _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.DEBUG)

+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control}, got {transformer.config.in_channels=}",
+        )
+
+        original_transformer_state_dict = pipe.transformer.state_dict()
+        x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
+        incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
+        self.assertTrue(
+            "x_embedder.weight" in incompatible_keys.missing_keys,
+            "Could not find x_embedder.weight in the missing keys.",
+        )
+        transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
+        pipe.transformer = transformer
+
         out_features, in_features = pipe.transformer.x_embedder.weight.shape
         rank = 4

@@ -234,11 +328,13 @@ def test_lora_parameter_expanded_shapes(self):
         }
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

+        lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
-
-        pipe.delete_adapters("adapter-1")
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -256,14 +352,20 @@ def test_lora_parameter_expanded_shapes(self):
         with self.assertRaises(NotImplementedError):
             pipe.load_lora_weights(lora_state_dict, "adapter-1")

-    # flux control lora specific
     @require_peft_version_greater("0.13.2")
     def test_lora_B_bias(self):
         components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

+        # keep track of the bias values of the base layers to perform checks later.
+        bias_values = {}
+        for name, module in pipe.transformer.named_modules():
+            if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
+                if module.bias is not None:
+                    bias_values[name] = module.bias.data.clone()
+
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         logger = logging.get_logger("diffusers.loaders.lora_pipeline")