 import numpy as np
 import safetensors.torch
 import torch
+from PIL import Image
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

-from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel
 from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -159,7 +160,80 @@ def test_with_alpha_in_state_dict(self):
         )
         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))

-    # flux control lora specific
+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+        pass
+
+    @unittest.skip("Not supported in Flux.")
+    def test_modify_padding_mode(self):
+        pass
+
+
+class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+    pipeline_class = FluxControlPipeline
+    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_kwargs = {}
+    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+    transformer_kwargs = {
+        "patch_size": 1,
+        "in_channels": 8,
+        "out_channels": 4,
+        "num_layers": 1,
+        "num_single_layers": 1,
+        "attention_head_dim": 16,
+        "num_attention_heads": 2,
+        "joint_attention_dim": 32,
+        "pooled_projection_dim": 32,
+        "axes_dims_rope": [4, 4, 8],
+    }
+    transformer_cls = FluxTransformer2DModel
+    vae_kwargs = {
+        "sample_size": 32,
+        "in_channels": 3,
+        "out_channels": 3,
+        "block_out_channels": (4,),
+        "layers_per_block": 1,
+        "latent_channels": 1,
+        "norm_num_groups": 1,
+        "use_quant_conv": False,
+        "use_post_quant_conv": False,
+        "shift_factor": 0.0609,
+        "scaling_factor": 1.5035,
+    }
+    has_two_text_encoders = True
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+    text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+    text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
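+    # The dummy pipeline runs at height=width=8 with output_type="np", hence images of shape (1, 8, 8, 3).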
+    @property
+    def output_shape(self):
+        return (1, 8, 8, 3)
+
+    def get_dummy_inputs(self, with_generator=True):
+        batch_size = 1
+        sequence_length = 10
+        num_channels = 4
+        sizes = (32, 32)
+
+        generator = torch.manual_seed(0)
+        noise = floats_tensor((batch_size, num_channels) + sizes)
+        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+        pipeline_inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),
+            "num_inference_steps": 4,
+            "guidance_scale": 0.0,
+            "height": 8,
+            "width": 8,
+            "output_type": "np",
+        }
+        if with_generator:
+            pipeline_inputs.update({"generator": generator})
+
+        return noise, input_ids, pipeline_inputs
+
     def test_with_norm_in_state_dict(self):
         components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
@@ -184,7 +258,7 @@ def test_with_norm_in_state_dict(self):

             with CaptureLogger(logger) as cap_logger:
                 pipe.load_lora_weights(norm_state_dict)
-            lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertTrue(
                 cap_logger.out.startswith(
@@ -211,18 +285,38 @@ def test_with_norm_in_state_dict(self):
             cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers")
         )

-    # flux control lora specific
     def test_lora_parameter_expanded_shapes(self):
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

         _, _, inputs = self.get_dummy_inputs(with_generator=False)
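+        # Run the pipeline once without LoRA to get a reference output; the LoRA output is compared against it below.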
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.DEBUG)

+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
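+        # Copy the base weights into the narrower transformer. x_embedder is handled separately below,
+        # reusing only its first num_channels_without_control input columns, since its input dimension changed.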
+        original_transformer_state_dict = pipe.transformer.state_dict()
+        x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
+        incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
+        self.assertTrue(
+            "x_embedder.weight" in incompatible_keys.missing_keys,
+            "Could not find x_embedder.weight in the missing keys.",
+        )
+        transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
+        pipe.transformer = transformer
+
         out_features, in_features = pipe.transformer.x_embedder.weight.shape
         rank = 4

@@ -234,11 +328,13 @@ def test_lora_parameter_expanded_shapes(self):
         }
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

+        lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
-
-        pipe.delete_adapters("adapter-1")
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -256,14 +352,20 @@ def test_lora_parameter_expanded_shapes(self):
         with self.assertRaises(NotImplementedError):
             pipe.load_lora_weights(lora_state_dict, "adapter-1")

-    # flux control lora specific
     @require_peft_version_greater("0.13.2")
     def test_lora_B_bias(self):
         components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

+        # keep track of the bias values of the base layers to perform checks later.
+        bias_values = {}
+        for name, module in pipe.transformer.named_modules():
+            if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
+                if module.bias is not None:
+                    bias_values[name] = module.bias.data.clone()
+
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         logger = logging.get_logger("diffusers.loaders.lora_pipeline")