diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index ffe460d72de8..08341d9c227c 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -17,6 +17,7 @@ import contextlib import copy import functools +import gc import logging import math import os @@ -52,6 +53,7 @@ from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory from diffusers.utils import check_min_version, is_wandb_available, make_image_grid from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.testing_utils import backend_empty_cache from diffusers.utils.torch_utils import is_compiled_module @@ -74,8 +76,9 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v pipeline = StableDiffusion3ControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, - controlnet=controlnet, + controlnet=None, safety_checker=None, + transformer=None, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -102,18 +105,55 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" ) + with torch.no_grad(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = pipeline.encode_prompt( + validation_prompts, + prompt_2=None, + prompt_3=None, + ) + + del pipeline + gc.collect() + backend_empty_cache(accelerator.device.type) + + pipeline = StableDiffusion3ControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + controlnet=controlnet, + safety_checker=None, + text_encoder=None, + text_encoder_2=None, + text_encoder_3=None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline.enable_model_cpu_offload(device=accelerator.device.type) + pipeline.set_progress_bar_config(disable=True) + image_logs = [] inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type) - for validation_prompt, validation_image in zip(validation_prompts, validation_images): + for i, validation_image in enumerate(validation_images): validation_image = Image.open(validation_image).convert("RGB") + validation_prompt = validation_prompts[i] images = [] for _ in range(args.num_validation_images): with inference_ctx: image = pipeline( - validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator + prompt_embeds=prompt_embeds[i].unsqueeze(0), + negative_prompt_embeds=negative_prompt_embeds[i].unsqueeze(0), + pooled_prompt_embeds=pooled_prompt_embeds[i].unsqueeze(0), + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[i].unsqueeze(0), + control_image=validation_image, + num_inference_steps=20, + generator=generator, ).images[0] images.append(image) @@ -655,6 +695,7 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce dataset = load_dataset( args.train_data_dir, cache_dir=args.cache_dir, + trust_remote_code=True, ) # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script