Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions examples/controlnet/train_controlnet_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import contextlib
import copy
import functools
import gc
import logging
import math
import os
Expand Down Expand Up @@ -52,6 +53,7 @@
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.testing_utils import backend_empty_cache
from diffusers.utils.torch_utils import is_compiled_module


Expand All @@ -74,8 +76,9 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v

pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
controlnet=controlnet,
controlnet=None,
safety_checker=None,
transformer=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
Expand All @@ -102,18 +105,55 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
)

with torch.no_grad():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipeline.encode_prompt(
validation_prompts,
prompt_2=None,
prompt_3=None,
Comment on lines +116 to +117
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a blocker for this PR, but it looks like `prompt_2` and `prompt_3` should be made `Optional` in the pipeline.

)

del pipeline
gc.collect()
backend_empty_cache(accelerator.device.type)

pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
controlnet=controlnet,
safety_checker=None,
text_encoder=None,
text_encoder_2=None,
text_encoder_3=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.enable_model_cpu_offload(device=accelerator.device.type)
pipeline.set_progress_bar_config(disable=True)

image_logs = []
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type)

for validation_prompt, validation_image in zip(validation_prompts, validation_images):
for i, validation_image in enumerate(validation_images):
validation_image = Image.open(validation_image).convert("RGB")
validation_prompt = validation_prompts[i]

images = []

for _ in range(args.num_validation_images):
with inference_ctx:
image = pipeline(
validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator
prompt_embeds=prompt_embeds[i].unsqueeze(0),
negative_prompt_embeds=negative_prompt_embeds[i].unsqueeze(0),
pooled_prompt_embeds=pooled_prompt_embeds[i].unsqueeze(0),
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[i].unsqueeze(0),
control_image=validation_image,
num_inference_steps=20,
generator=generator,
).images[0]

images.append(image)
Expand Down Expand Up @@ -655,6 +695,7 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce
dataset = load_dataset(
args.train_data_dir,
cache_dir=args.cache_dir,
trust_remote_code=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. This could be an argument, but it is more convenient like this, especially as the example dataset requires it. It can be replicated across the other training scripts.
For future reference, we should look at the `num_proc` option, which should help speed up processing.

)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
Expand Down
Loading