Skip to content

Commit 579bb5f

Browse files
committed
fix: SD3 ControlNet validation so that it runs on an A100.
1 parent f685981 commit 579bb5f

File tree

1 file changed

+43
-3
lines changed

1 file changed

+43
-3
lines changed

examples/controlnet/train_controlnet_sd3.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import contextlib
1818
import copy
1919
import functools
20+
import gc
2021
import logging
2122
import math
2223
import os
@@ -74,8 +75,9 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
7475

7576
pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
7677
args.pretrained_model_name_or_path,
77-
controlnet=controlnet,
78+
controlnet=None,
7879
safety_checker=None,
80+
transformer=None,
7981
revision=args.revision,
8082
variant=args.variant,
8183
torch_dtype=weight_dtype,
@@ -102,18 +104,55 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
102104
"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
103105
)
104106

107+
with torch.no_grad():
108+
(
109+
prompt_embeds,
110+
negative_prompt_embeds,
111+
pooled_prompt_embeds,
112+
negative_pooled_prompt_embeds,
113+
) = pipeline.encode_prompt(
114+
validation_prompts,
115+
prompt_2=None,
116+
prompt_3=None,
117+
)
118+
119+
del pipeline
120+
gc.collect()
121+
torch.cuda.empty_cache()
122+
123+
pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
124+
args.pretrained_model_name_or_path,
125+
controlnet=controlnet,
126+
safety_checker=None,
127+
text_encoder=None,
128+
text_encoder_2=None,
129+
text_encoder_3=None,
130+
revision=args.revision,
131+
variant=args.variant,
132+
torch_dtype=weight_dtype,
133+
)
134+
pipeline.enable_model_cpu_offload()
135+
pipeline.set_progress_bar_config(disable=True)
136+
105137
image_logs = []
106138
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type)
107139

108-
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
140+
for i, validation_image in enumerate(validation_images):
109141
validation_image = Image.open(validation_image).convert("RGB")
142+
validation_prompt = validation_prompts[i]
110143

111144
images = []
112145

113146
for _ in range(args.num_validation_images):
114147
with inference_ctx:
115148
image = pipeline(
116-
validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator
149+
prompt_embeds=prompt_embeds[i].unsqueeze(0),
150+
negative_prompt_embeds=negative_prompt_embeds[i].unsqueeze(0),
151+
pooled_prompt_embeds=pooled_prompt_embeds[i].unsqueeze(0),
152+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[i].unsqueeze(0),
153+
control_image=validation_image,
154+
num_inference_steps=20,
155+
generator=generator,
117156
).images[0]
118157

119158
images.append(image)
@@ -655,6 +694,7 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce
655694
dataset = load_dataset(
656695
args.train_data_dir,
657696
cache_dir=args.cache_dir,
697+
trust_remote_code=True,
658698
)
659699
# See more about loading custom images at
660700
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script

0 commit comments

Comments
 (0)