
Commit 6ce2307

Merge branch 'flux-control-lora' into flux-control-lora-training-script
2 parents: 9a83eff + 3204627

File tree: 4 files changed (+126, -29 lines)

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 5 additions & 1 deletion
@@ -673,6 +673,10 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     inner_dim = 3072
     mlp_ratio = 4.0
 
+    for k in original_state_dict:
+        if "bias" in k and "img_in" in k:
+            print(f"{k=}")
+
     def swap_scale_shift(weight):
         shift, scale = weight.chunk(2, dim=0)
         new_weight = torch.cat([scale, shift], dim=0)

@@ -750,7 +754,7 @@ def swap_scale_shift(weight):
     for i in range(num_layers):
         block_prefix = f"transformer_blocks.{i}."
 
-        for lora_key, lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]):
+        for lora_key in ["lora_A", "lora_B"]:
             # norms
             converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
                 f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"

src/diffusers/loaders/lora_pipeline.py

Lines changed: 11 additions & 20 deletions
@@ -427,7 +427,7 @@ def load_lora_into_text_encoder(
             if lora_config_kwargs["lora_bias"]:
                 if is_peft_version("<=", "0.13.2"):
                     raise ValueError(
-                        "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
                     )
             else:
                 if is_peft_version("<=", "0.13.2"):
@@ -970,7 +970,7 @@ def load_lora_into_text_encoder(
             if lora_config_kwargs["lora_bias"]:
                 if is_peft_version("<=", "0.13.2"):
                     raise ValueError(
-                        "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
                     )
             else:
                 if is_peft_version("<=", "0.13.2"):
@@ -1479,7 +1479,7 @@ def load_lora_into_text_encoder(
             if lora_config_kwargs["lora_bias"]:
                 if is_peft_version("<=", "0.13.2"):
                     raise ValueError(
-                        "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
                     )
             else:
                 if is_peft_version("<=", "0.13.2"):
@@ -2108,7 +2108,7 @@ def load_lora_into_text_encoder(
             if lora_config_kwargs["lora_bias"]:
                 if is_peft_version("<=", "0.13.2"):
                     raise ValueError(
-                        "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
                     )
             else:
                 if is_peft_version("<=", "0.13.2"):
@@ -2246,7 +2246,7 @@ def fuse_lora(
         ):
             logger.info(
                 "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer "
-                'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
+                "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
                 "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
             )
 
@@ -2318,14 +2318,13 @@ def _maybe_expand_transformer_param_shape_or_error_(
 
                 lora_A_weight_name = f"{name}.lora_A.weight"
                 lora_B_weight_name = f"{name}.lora_B.weight"
-                lora_B_bias_name = f"{name}.lora_B.bias"
-
                 if lora_A_weight_name not in state_dict.keys():
                     continue
 
                 in_features = state_dict[lora_A_weight_name].shape[1]
                 out_features = state_dict[lora_B_weight_name].shape[0]
 
+                # This means there's no need for an expansion in the params, so we simply skip.
                 if tuple(module_weight.shape) == (out_features, in_features):
                     continue
 
@@ -2349,27 +2348,19 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 parent_module_name, _, current_module_name = name.rpartition(".")
                 parent_module = transformer.get_submodule(parent_module_name)
 
+                # TODO: consider initializing this under meta device for optims.
                 expanded_module = torch.nn.Linear(
                     in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
                 )
-
+                # Only weights are expanded and biases are not.
                 new_weight = torch.zeros_like(
                     expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
                 )
                 slices = tuple(slice(0, dim) for dim in module_weight.shape)
                 new_weight[slices] = module_weight
                 expanded_module.weight.data.copy_(new_weight)
-
-                bias_present_for_lora_B = lora_B_bias_name in state_dict
-                if bias_present_for_lora_B:
-                    new_bias_shape = state_dict[lora_B_bias_name].shape
-                    if bias and module_bias.shape < new_bias_shape:
-                        new_bias = torch.zeros_like(
-                            expanded_module.bias.data, device=module_bias.device, dtype=module_bias.dtype
-                        )
-                        slices = tuple(slice(0, dim) for dim in module_bias.shape)
-                        new_bias[slices] = module_bias
-                        expanded_module.bias.data.copy_(new_bias)
+                if module_bias is not None:
+                    expanded_module.bias.data.copy_(module_bias)
 
                 setattr(parent_module, current_module_name, expanded_module)
 
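The hunk above simplifies how `_maybe_expand_transformer_param_shape_or_error_` handles biases: only the weight matrix is zero-padded to the expanded shape, and whatever bias the base layer already has is copied over unchanged. A standalone sketch of that behavior with hypothetical sizes (not the actual diffusers helper, which also reads shapes from the LoRA state dict):

```python
import torch

# Hypothetical module: a Linear whose in_features must grow from 4 to 8,
# as happens when a Control LoRA doubles x_embedder's input channels.
module = torch.nn.Linear(4, 16, bias=True)
module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None

in_features, out_features = 8, 16
expanded_module = torch.nn.Linear(
    in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
)

# Only weights are expanded (zero-padded into the larger matrix); biases are not.
new_weight = torch.zeros_like(expanded_module.weight.data)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight
expanded_module.weight.data.copy_(new_weight)
if module_bias is not None:
    expanded_module.bias.data.copy_(module_bias)

assert expanded_module.weight.shape == (16, 8)
assert torch.equal(expanded_module.bias.data, module_bias)
```
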
@@ -2551,7 +2542,7 @@ def load_lora_into_text_encoder(
             if lora_config_kwargs["lora_bias"]:
                 if is_peft_version("<=", "0.13.2"):
                     raise ValueError(
-                        "You need `peft` 0.13.3 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
                     )
             else:
                 if is_peft_version("<=", "0.13.2"):

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 1 deletion
@@ -293,7 +293,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             if lora_config_kwargs["lora_bias"]:
                 if is_peft_version("<=", "0.13.2"):
                     raise ValueError(
-                        "You need `peft` 0.13.3 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
+                        "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
                     )
             else:
                 if is_peft_version("<=", "0.13.2"):
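
Together with the lora_pipeline.py hunks, this corrects the minimum version named in the error message: `lora_bias` support landed in peft 0.14.0, while the actual check (`<= 0.13.2`) is unchanged. A hedged sketch of the guard pattern as used here (the kwargs dict is a stand-in for what diffusers derives from the LoRA state dict):

```python
from diffusers.utils import is_peft_version

# Stand-in for the kwargs diffusers builds while preparing a LoraConfig.
lora_config_kwargs = {"lora_bias": True}

if lora_config_kwargs["lora_bias"]:
    # Loading a LoRA that trains lora_B biases requires a peft release newer than 0.13.2.
    if is_peft_version("<=", "0.13.2"):
        raise ValueError(
            "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
        )
```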

tests/lora/test_lora_layers_flux.py

Lines changed: 109 additions & 7 deletions
@@ -21,9 +21,10 @@
 import numpy as np
 import safetensors.torch
 import torch
+from PIL import Image
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
 
-from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel
 from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -159,7 +160,80 @@ def test_with_alpha_in_state_dict(self):
         )
         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
 
-    # flux control lora specific
+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+        pass
+
+    @unittest.skip("Not supported in Flux.")
+    def test_modify_padding_mode(self):
+        pass
+
+
+class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+    pipeline_class = FluxControlPipeline
+    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_kwargs = {}
+    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+    transformer_kwargs = {
+        "patch_size": 1,
+        "in_channels": 8,
+        "out_channels": 4,
+        "num_layers": 1,
+        "num_single_layers": 1,
+        "attention_head_dim": 16,
+        "num_attention_heads": 2,
+        "joint_attention_dim": 32,
+        "pooled_projection_dim": 32,
+        "axes_dims_rope": [4, 4, 8],
+    }
+    transformer_cls = FluxTransformer2DModel
+    vae_kwargs = {
+        "sample_size": 32,
+        "in_channels": 3,
+        "out_channels": 3,
+        "block_out_channels": (4,),
+        "layers_per_block": 1,
+        "latent_channels": 1,
+        "norm_num_groups": 1,
+        "use_quant_conv": False,
+        "use_post_quant_conv": False,
+        "shift_factor": 0.0609,
+        "scaling_factor": 1.5035,
+    }
+    has_two_text_encoders = True
+    tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+    tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+    text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+    text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+    @property
+    def output_shape(self):
+        return (1, 8, 8, 3)
+
+    def get_dummy_inputs(self, with_generator=True):
+        batch_size = 1
+        sequence_length = 10
+        num_channels = 4
+        sizes = (32, 32)
+
+        generator = torch.manual_seed(0)
+        noise = floats_tensor((batch_size, num_channels) + sizes)
+        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+        pipeline_inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),
+            "num_inference_steps": 4,
+            "guidance_scale": 0.0,
+            "height": 8,
+            "width": 8,
+            "output_type": "np",
+        }
+        if with_generator:
+            pipeline_inputs.update({"generator": generator})
+
+        return noise, input_ids, pipeline_inputs
+
     def test_with_norm_in_state_dict(self):
         components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
@@ -184,7 +258,7 @@ def test_with_norm_in_state_dict(self):
 
             with CaptureLogger(logger) as cap_logger:
                 pipe.load_lora_weights(norm_state_dict)
-            lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
             self.assertTrue(
                 cap_logger.out.startswith(
@@ -211,18 +285,38 @@ def test_with_norm_in_state_dict(self):
             cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers")
         )
 
-    # flux control lora specific
     def test_lora_parameter_expanded_shapes(self):
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.DEBUG)
 
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
+        original_transformer_state_dict = pipe.transformer.state_dict()
+        x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
+        incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
+        self.assertTrue(
+            "x_embedder.weight" in incompatible_keys.missing_keys,
+            "Could not find x_embedder.weight in the missing keys.",
+        )
+        transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
+        pipe.transformer = transformer
+
         out_features, in_features = pipe.transformer.x_embedder.weight.shape
         rank = 4
 
@@ -234,11 +328,13 @@ def test_lora_parameter_expanded_shapes(self):
         }
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
 
+        lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
-
-        pipe.delete_adapters("adapter-1")
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -256,14 +352,20 @@ def test_lora_parameter_expanded_shapes(self):
         with self.assertRaises(NotImplementedError):
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
 
-    # flux control lora specific
     @require_peft_version_greater("0.13.2")
     def test_lora_B_bias(self):
        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
+        # keep track of the bias values of the base layers to perform checks later.
+        bias_values = {}
+        for name, module in pipe.transformer.named_modules():
+            if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
+                if module.bias is not None:
+                    bias_values[name] = module.bias.data.clone()
+
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
