Merged
Changes from 2 commits
46 changes: 43 additions & 3 deletions examples/controlnet/train_controlnet_sd3.py
@@ -17,6 +17,7 @@
import contextlib
import copy
import functools
import gc
import logging
import math
import os
@@ -74,8 +75,9 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v

pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
controlnet=controlnet,
controlnet=None,
safety_checker=None,
transformer=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -102,18 +104,55 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
)

with torch.no_grad():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipeline.encode_prompt(
validation_prompts,
prompt_2=None,
prompt_3=None,
Comment on lines +116 to +117
Contributor

Not a blocker for this PR, but it looks like prompt_2 and prompt_3 should be made Optional in the pipeline.
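A minimal sketch of that suggestion, with an assumed standalone signature rather than the actual pipeline method: default `prompt_2`/`prompt_3` to None and fall back to the primary prompt, so call sites like the one above no longer have to pass them explicitly.

```python
from typing import List, Optional, Union


def encode_prompt(
    prompt: Union[str, List[str]],
    prompt_2: Optional[Union[str, List[str]]] = None,  # assumed Optional default
    prompt_3: Optional[Union[str, List[str]]] = None,  # assumed Optional default
):
    # Fall back to the primary prompt when the prompts for the second and
    # third text encoders are omitted.
    prompt_2 = prompt_2 if prompt_2 is not None else prompt
    prompt_3 = prompt_3 if prompt_3 is not None else prompt
    # ... the real method would tokenize and encode here ...
    return prompt, prompt_2, prompt_3


# With such defaults, the call above could reduce to:
# pipeline.encode_prompt(validation_prompts)
```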

)

del pipeline
gc.collect()
torch.cuda.empty_cache()

pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
controlnet=controlnet,
safety_checker=None,
text_encoder=None,
text_encoder_2=None,
text_encoder_3=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.enable_model_cpu_offload()
pipeline.set_progress_bar_config(disable=True)

image_logs = []
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type)

for validation_prompt, validation_image in zip(validation_prompts, validation_images):
for i, validation_image in enumerate(validation_images):
validation_image = Image.open(validation_image).convert("RGB")
validation_prompt = validation_prompts[i]

images = []

for _ in range(args.num_validation_images):
with inference_ctx:
image = pipeline(
validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator
prompt_embeds=prompt_embeds[i].unsqueeze(0),
negative_prompt_embeds=negative_prompt_embeds[i].unsqueeze(0),
pooled_prompt_embeds=pooled_prompt_embeds[i].unsqueeze(0),
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[i].unsqueeze(0),
control_image=validation_image,
num_inference_steps=20,
generator=generator,
).images[0]

images.append(image)
@@ -655,6 +694,7 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce
dataset = load_dataset(
args.train_data_dir,
cache_dir=args.cache_dir,
trust_remote_code=True,
Contributor

Nice. This could be an argument, but it's more convenient like this, especially as the example dataset requires it. It can be replicated across the training scripts.
For future reference, we should look at the num_proc option, which should help speed up processing.
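A rough sketch of both suggestions; the `--trust_remote_code` and `--dataset_num_proc` flag names are assumptions, not something this PR adds.

```python
import argparse

from datasets import load_dataset

parser = argparse.ArgumentParser()
parser.add_argument("--train_data_dir", type=str, default=None)
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument(
    "--trust_remote_code",
    action="store_true",
    help="Allow the dataset's own loading script to run (the example dataset requires it).",
)
parser.add_argument(
    "--dataset_num_proc",
    type=int,
    default=None,
    help="Forwarded to load_dataset as num_proc to parallelize download/preparation.",
)
args = parser.parse_args()

dataset = load_dataset(
    args.train_data_dir,
    cache_dir=args.cache_dir,
    trust_remote_code=args.trust_remote_code,
    num_proc=args.dataset_num_proc,
)
```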

)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script