
Commit 7b82bdc

Merge branch 'main' into Sana

2 parents ea7878c + 22c4f07

File tree: 6 files changed (+166 / -11 lines)


examples/flux-control/README.md

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,7 @@ accelerate launch train_control_lora_flux.py \
   --max_train_steps=5000 \
   --validation_image="openpose.png" \
   --validation_prompt="A couple, 4k photo, highly detailed" \
+  --offload \
   --seed="0" \
   --push_to_hub
 ```
@@ -154,6 +155,7 @@ accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \
   --validation_steps=200 \
   --validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
   --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
+  --offload \
   --seed="0" \
   --push_to_hub
 ```
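Both commands gain a new `--offload` flag. It is opt-in GPU memory relief: the VAE and the text encoders are moved onto the accelerator only for their encode calls and parked back on CPU afterwards, leaving more headroom for the Flux transformer's forward and backward passes. A minimal sketch of the pattern the flag enables (the helper name and wiring are illustrative, not the script's exact code):

```python
import torch

def encode_control_images(vae, pixel_values, device, offload=True):
    # Bring the VAE onto the accelerator only for this encode call.
    vae.to(device)
    with torch.no_grad():
        latents = vae.encode(pixel_values.to(device)).latent_dist.sample()
    if offload:
        # Park the VAE back on CPU so the transformer keeps the GPU memory.
        vae.cpu()
    return latents
```

The trade-off is an extra host-to-device transfer per step, so the flag mainly pays off when the transformer alone nearly fills GPU memory.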

examples/flux-control/train_control_flux.py

Lines changed: 10 additions & 3 deletions
@@ -541,6 +541,11 @@ def parse_args(input_args=None):
         default=1.29,
         help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
+    parser.add_argument(
+        "--offload",
+        action="store_true",
+        help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -999,8 +1004,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 control_latents = encode_images(
                     batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
                 )
-                # offload vae to CPU.
-                vae.cpu()
+                if args.offload:
+                    # offload vae to CPU.
+                    vae.cpu()
 
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
@@ -1064,7 +1070,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
                     prompt_embeds.zero_()
                     pooled_prompt_embeds.zero_()
-                text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+                if args.offload:
+                    text_encoding_pipeline = text_encoding_pipeline.to("cpu")
 
                 # Predict.
                 model_pred = flux_transformer(
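Because the new argument uses `action="store_true"`, offloading stays disabled unless the flag is passed on the command line, and the `if args.offload:` blocks above are no-ops by default. A self-contained illustration of that argparse behavior:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--offload",
    action="store_true",
    help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
)

print(parser.parse_args([]).offload)             # False: offloading is off by default
print(parser.parse_args(["--offload"]).offload)  # True: enabled only when the flag is given
```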

examples/flux-control/train_control_lora_flux.py

Lines changed: 11 additions & 3 deletions
@@ -573,6 +573,11 @@ def parse_args(input_args=None):
         default=1.29,
         help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
+    parser.add_argument(
+        "--offload",
+        action="store_true",
+        help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1140,8 +1145,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 control_latents = encode_images(
                     batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
                 )
-                # offload vae to CPU.
-                vae.cpu()
+
+                if args.offload:
+                    # offload vae to CPU.
+                    vae.cpu()
 
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
@@ -1205,7 +1212,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
                     prompt_embeds.zero_()
                     pooled_prompt_embeds.zero_()
-                text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+                if args.offload:
+                    text_encoding_pipeline = text_encoding_pipeline.to("cpu")
 
                 # Predict.
                 model_pred = flux_transformer(

src/diffusers/loaders/lora_pipeline.py

Lines changed: 10 additions & 3 deletions
@@ -2337,12 +2337,19 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     f"this please open an issue at https://github.com/huggingface/diffusers/issues."
                 )
 
-            logger.debug(
+            debug_message = (
                 f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
                 f"checkpoint contains higher number of features than expected. The number of input_features will be "
-                f"expanded from {module_in_features} to {in_features}, and the number of output features will be "
-                f"expanded from {module_out_features} to {out_features}."
+                f"expanded from {module_in_features} to {in_features}"
             )
+            if module_out_features != out_features:
+                debug_message += (
+                    ", and the number of output features will be "
+                    f"expanded from {module_out_features} to {out_features}."
+                )
+            else:
+                debug_message += "."
+            logger.debug(debug_message)
 
             has_param_with_shape_update = True
             parent_module_name, _, current_module_name = name.rpartition(".")
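This change only affects the wording of the debug log: when a LoRA checkpoint expands the input features but leaves the output features unchanged, the old message claimed the output features would be expanded from N to N. The message now mentions the output expansion only when it actually happens. A standalone sketch of the same message-building logic (the function name is illustrative):

```python
def build_expansion_message(name, module_in_features, in_features, module_out_features, out_features):
    # Always report the input expansion; add the output clause only when the value changes.
    message = (
        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
        f"checkpoint contains higher number of features than expected. The number of input_features will be "
        f"expanded from {module_in_features} to {in_features}"
    )
    if module_out_features != out_features:
        message += f", and the number of output features will be expanded from {module_out_features} to {out_features}."
    else:
        message += "."
    return message

print(build_expansion_message("x_embedder", 64, 128, 3072, 3072))
```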

src/diffusers/loaders/peft.py

Lines changed: 17 additions & 2 deletions
@@ -205,6 +205,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             weights.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        from peft.tuners.tuners_utils import BaseTunerLayer
 
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
@@ -316,8 +317,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
-        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
-        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+        # To handle scenarios where we cannot successfully set the state dict. If it's unsuccessful,
+        # we should also delete the `peft_config` associated with the `adapter_name`.
+        try:
+            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+        except RuntimeError as e:
+            for module in self.modules():
+                if isinstance(module, BaseTunerLayer):
+                    active_adapters = module.active_adapters
+                    for active_adapter in active_adapters:
+                        if adapter_name in active_adapter:
+                            module.delete_adapter(adapter_name)
+
+            self.peft_config.pop(adapter_name)
+            logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
+            raise
 
         warn_msg = ""
         if incompatible_keys is not None:
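The point of the new `try`/`except` is to leave the model in a clean state when `set_peft_model_state_dict` fails (for example, on a shape mismatch): any LoRA layers that `inject_adapter_in_model` already created for `adapter_name` are deleted, the matching `peft_config` entry is dropped, and the original error is re-raised. A rough standalone sketch of that rollback idea, using placeholder names rather than the real PEFT calls:

```python
class AdapterHost:
    """Illustrative stand-in for a model that tracks per-adapter state."""

    def __init__(self):
        self.peft_config = {}
        self.adapter_weights = {}

    def load_adapter(self, name, state_dict):
        self.peft_config[name] = {"rank": 4}      # step 1: register the adapter config (inject)
        try:
            self._set_weights(name, state_dict)   # step 2: load the weights (set state dict)
        except RuntimeError:
            self.adapter_weights.pop(name, None)  # roll back any partially created state
            self.peft_config.pop(name)            # drop the config so the model stays clean
            raise                                 # surface the original failure to the caller

    def _set_weights(self, name, state_dict):
        if not state_dict:
            raise RuntimeError("size mismatch")   # simulated failure
        self.adapter_weights[name] = state_dict


host = AdapterHost()
try:
    host.load_adapter("adapter-1", {})
except RuntimeError:
    pass
assert "adapter-1" not in host.peft_config  # the failed load left no residue behind
```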

tests/lora/test_lora_layers_flux.py

Lines changed: 116 additions & 0 deletions
@@ -430,6 +430,122 @@ def test_correct_lora_configs_with_different_ranks(self):
         self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
         self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
 
+    def test_lora_expanding_shape_with_normal_lora_raises_error(self):
+        # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
+        # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
+        # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        components["transformer"] = transformer
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
+
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+        self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
+        normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+        }
+
+        # The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
+        # input features before expansion. This should raise an error about the weight shapes being incompatible.
+        self.assertRaisesRegex(
+            RuntimeError,
+            "size mismatch for x_embedder.lora_A.adapter-2.weight",
+            pipe.load_lora_weights,
+            lora_state_dict,
+            "adapter-2",
+        )
+        # We should have `adapter-1` as the only adapter.
+        self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
+
+        # Check if the output is the same after lora loading error
+        lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3))
+
+        # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
+        # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
+        # original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
+        # weight is compatible with the current model inadequate. This should be addressed when attempting support for
+        # https://github.com/huggingface/diffusers/issues/10180 (TODO)
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        components["transformer"] = transformer
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+        }
+
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features)
+        self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
+        }
+
+        # We should check for input shapes being incompatible here. But because above mentioned issue is
+        # not a supported use case, and because of the PEFT renaming, we will currently have a shape
+        # mismatch error.
+        self.assertRaisesRegex(
+            RuntimeError,
+            "size mismatch for x_embedder.lora_A.adapter-2.weight",
+            pipe.load_lora_weights,
+            lora_state_dict,
+            "adapter-2",
+        )
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
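The two pairs of `torch.nn.Linear` layers in the test capture the difference between a Control-style, shape-expanding LoRA and a normal one: the expanding variant's `lora_A` expects twice as many input features as the base `x_embedder`, which is what forces the loader to grow the base layer. A compact illustration of just that shape difference (the dimensions are made up for the example):

```python
import torch

in_features, out_features, rank = 16, 32, 4

# Normal LoRA: lora_A matches the base layer's input width.
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)

# Shape-expanding (Control-style) LoRA: lora_A expects 2x the input width,
# e.g. because control latents are concatenated with the image latents.
shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)

print(normal_lora_A.weight.shape)          # torch.Size([4, 16])
print(shape_expander_lora_A.weight.shape)  # torch.Size([4, 32])
```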
