Commit 55058e2
Commit message: tests
Parent: ecbc4cb

1 file changed: tests/lora/test_lora_layers_flux.py (109 additions, 7 deletions)
@@ -21,9 +21,10 @@
 import numpy as np
 import safetensors.torch
 import torch
+from PIL import Image
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
 
-from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel
 from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -159,7 +160,80 @@ def test_with_alpha_in_state_dict(self):
         )
         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
 
-    # flux control lora specific
+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+        pass
+
+    @unittest.skip("Not supported in Flux.")
+    def test_modify_padding_mode(self):
+        pass
+
+
+class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+    pipeline_class = FluxControlPipeline
+    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_kwargs = {}
+    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+    transformer_kwargs = {
+        "patch_size": 1,
+        "in_channels": 8,
+        "out_channels": 4,
+        "num_layers": 1,
+        "num_single_layers": 1,
+        "attention_head_dim": 16,
+        "num_attention_heads": 2,
+        "joint_attention_dim": 32,
+        "pooled_projection_dim": 32,
+        "axes_dims_rope": [4, 4, 8],
+    }
+    transformer_cls = FluxTransformer2DModel
+    vae_kwargs = {
+        "sample_size": 32,
+        "in_channels": 3,
+        "out_channels": 3,
+        "block_out_channels": (4,),
+        "layers_per_block": 1,
+        "latent_channels": 1,
+        "norm_num_groups": 1,
+        "use_quant_conv": False,
+        "use_post_quant_conv": False,
+        "shift_factor": 0.0609,
+        "scaling_factor": 1.5035,
+    }
+    has_two_text_encoders = True
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+    text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+    text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+    @property
+    def output_shape(self):
+        return (1, 8, 8, 3)
+
+    def get_dummy_inputs(self, with_generator=True):
+        batch_size = 1
+        sequence_length = 10
+        num_channels = 4
+        sizes = (32, 32)
+
+        generator = torch.manual_seed(0)
+        noise = floats_tensor((batch_size, num_channels) + sizes)
+        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+        pipeline_inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),
+            "num_inference_steps": 4,
+            "guidance_scale": 0.0,
+            "height": 8,
+            "width": 8,
+            "output_type": "np",
+        }
+        if with_generator:
+            pipeline_inputs.update({"generator": generator})
+
+        return noise, input_ids, pipeline_inputs
+
     def test_with_norm_in_state_dict(self):
         components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
@@ -184,7 +258,7 @@ def test_with_norm_in_state_dict(self):
 
             with CaptureLogger(logger) as cap_logger:
                 pipe.load_lora_weights(norm_state_dict)
-            lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
             self.assertTrue(
                 cap_logger.out.startswith(
@@ -211,18 +285,38 @@ def test_with_norm_in_state_dict(self):
             cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers")
         )
 
-    # flux control lora specific
     def test_lora_parameter_expanded_shapes(self):
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.DEBUG)
 
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
+        original_transformer_state_dict = pipe.transformer.state_dict()
+        x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
+        incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
+        self.assertTrue(
+            "x_embedder.weight" in incompatible_keys.missing_keys,
+            "Could not find x_embedder.weight in the missing keys.",
+        )
+        transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
+        pipe.transformer = transformer
+
         out_features, in_features = pipe.transformer.x_embedder.weight.shape
         rank = 4
 
@@ -234,11 +328,13 @@ def test_lora_parameter_expanded_shapes(self):
         }
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
 
+        lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
-
-        pipe.delete_adapters("adapter-1")
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -256,14 +352,20 @@ def test_lora_parameter_expanded_shapes(self):
         with self.assertRaises(NotImplementedError):
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
 
-    # flux control lora specific
     @require_peft_version_greater("0.13.2")
     def test_lora_B_bias(self):
         components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
+        # keep track of the bias values of the base layers to perform checks later.
+        bias_values = {}
+        for name, module in pipe.transformer.named_modules():
+            if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
+                if module.bias is not None:
+                    bias_values[name] = module.bias.data.clone()
+
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
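The behaviour exercised by test_lora_parameter_expanded_shapes is that loading a Flux Control LoRA whose lora_A matrices expect more input features than the base x_embedder provides makes diffusers widen that nn.Linear. The snippet below is a minimal plain-PyTorch sketch of that idea, not the diffusers implementation; the helper name expand_linear_in_features and the zero-initialisation of the new input columns are assumptions for illustration.

# Minimal sketch (plain PyTorch, not the diffusers implementation) of the
# shape-expansion behaviour the test checks: when a control LoRA's lora_A
# expects more input features than the base x_embedder provides, the base
# nn.Linear is replaced by a wider one.
import torch
import torch.nn as nn


def expand_linear_in_features(base: nn.Linear, new_in_features: int) -> nn.Linear:
    # Build a wider linear layer and copy the original weights into the
    # leading columns; the extra columns start at zero (an assumption here),
    # so existing inputs produce the same outputs as before.
    assert new_in_features >= base.in_features
    expanded = nn.Linear(new_in_features, base.out_features, bias=base.bias is not None)
    with torch.no_grad():
        expanded.weight.zero_()
        expanded.weight[:, : base.in_features].copy_(base.weight)
        if base.bias is not None:
            expanded.bias.copy_(base.bias)
    return expanded


base = nn.Linear(4, 8)       # stands in for x_embedder with in_channels=4
lora_A = torch.randn(4, 8)   # rank-4 LoRA A matrix expecting 8 inputs (image + control latents)
expanded = expand_linear_in_features(base, lora_A.shape[1])
assert expanded.weight.shape == (8, 8)  # out_features x (2 * in_features), mirroring the test's assertion

The new tests can be run from a diffusers checkout with, for example, pytest tests/lora/test_lora_layers_flux.py -k FluxControlLoRATests.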

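The bias_values dict added to test_lora_B_bias follows a common before/after pattern: snapshot the biases of the attention projections before loading a LoRA that carries B-side bias terms, so later assertions can compare parameters or outputs against that baseline. A toy illustration of the pattern, using hypothetical module names in plain PyTorch rather than the Flux transformer:

# Toy sketch (hypothetical module names, plain PyTorch) of the bookkeeping
# pattern behind bias_values in test_lora_B_bias.
import torch
import torch.nn as nn

attn = nn.ModuleDict({"to_q": nn.Linear(16, 16), "to_k": nn.Linear(16, 16), "to_v": nn.Linear(16, 16)})

# Keep track of the base-layer bias values to perform checks later.
bias_values = {name: mod.bias.data.clone() for name, mod in attn.items() if mod.bias is not None}

# ... in the real test, a LoRA configured with bias terms would be loaded here ...

for name, mod in attn.items():
    # Nothing modified the biases in this sketch, so they match the snapshot;
    # the real test uses the snapshot to reason about the effect of the LoRA bias.
    assert torch.allclose(mod.bias.data, bias_values[name])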