1717import contextlib
1818import copy
1919import functools
20+ import gc
2021import logging
2122import math
2223import os
@@ -74,8 +75,9 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
7475
7576 pipeline = StableDiffusion3ControlNetPipeline .from_pretrained (
7677 args .pretrained_model_name_or_path ,
77- controlnet = controlnet ,
78+ controlnet = None ,
7879 safety_checker = None ,
80+ transformer = None ,
7981 revision = args .revision ,
8082 variant = args .variant ,
8183 torch_dtype = weight_dtype ,
@@ -102,18 +104,55 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
102104 "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
103105 )
104106
107+ with torch .no_grad ():
108+ (
109+ prompt_embeds ,
110+ negative_prompt_embeds ,
111+ pooled_prompt_embeds ,
112+ negative_pooled_prompt_embeds ,
113+ ) = pipeline .encode_prompt (
114+ validation_prompts ,
115+ prompt_2 = None ,
116+ prompt_3 = None ,
117+ )
118+
119+ del pipeline
120+ gc .collect ()
121+ torch .cuda .empty_cache ()
122+
123+ pipeline = StableDiffusion3ControlNetPipeline .from_pretrained (
124+ args .pretrained_model_name_or_path ,
125+ controlnet = controlnet ,
126+ safety_checker = None ,
127+ text_encoder = None ,
128+ text_encoder_2 = None ,
129+ text_encoder_3 = None ,
130+ revision = args .revision ,
131+ variant = args .variant ,
132+ torch_dtype = weight_dtype ,
133+ )
134+ pipeline .enable_model_cpu_offload ()
135+ pipeline .set_progress_bar_config (disable = True )
136+
105137 image_logs = []
106138 inference_ctx = contextlib .nullcontext () if is_final_validation else torch .autocast (accelerator .device .type )
107139
108- for validation_prompt , validation_image in zip ( validation_prompts , validation_images ):
140+ for i , validation_image in enumerate ( validation_images ):
109141 validation_image = Image .open (validation_image ).convert ("RGB" )
142+ validation_prompt = validation_prompts [i ]
110143
111144 images = []
112145
113146 for _ in range (args .num_validation_images ):
114147 with inference_ctx :
115148 image = pipeline (
116- validation_prompt , control_image = validation_image , num_inference_steps = 20 , generator = generator
149+ prompt_embeds = prompt_embeds [i ].unsqueeze (0 ),
150+ negative_prompt_embeds = negative_prompt_embeds [i ].unsqueeze (0 ),
151+ pooled_prompt_embeds = pooled_prompt_embeds [i ].unsqueeze (0 ),
152+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds [i ].unsqueeze (0 ),
153+ control_image = validation_image ,
154+ num_inference_steps = 20 ,
155+ generator = generator ,
117156 ).images [0 ]
118157
119158 images .append (image )
@@ -655,6 +694,7 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce
655694 dataset = load_dataset (
656695 args .train_data_dir ,
657696 cache_dir = args .cache_dir ,
697+ trust_remote_code = True ,
658698 )
659699 # See more about loading custom images at
660700 # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
0 commit comments