Skip to content

Commit 90e9517

Browse files
committed
log_validation with mixed precision
1 parent 364f478 commit 90e9517

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,19 @@ def log_validation(
228228

229229
# run inference
230230
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
231-
autocast_ctx = nullcontext()
232-
233-
with autocast_ctx:
234-
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
231+
autocast_ctx = torch.autocast(accelerator.device.type)
232+
233+
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
234+
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
235+
)
236+
images = []
237+
for _ in range(args.num_validation_images):
238+
with autocast_ctx:
239+
image = pipeline(
240+
prompt_embeds=prompt_embeds,
241+
pooled_prompt_embeds=pooled_prompt_embeds,
242+
generator=generator).images[0]
243+
images.append(image)
235244

236245
for tracker in accelerator.trackers:
237246
phase_name = "test" if is_final_validation else "validation"

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,24 @@ def log_validation(
177177
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
178178
f" {args.validation_prompt}."
179179
)
180-
pipeline = pipeline.to(accelerator.device)
180+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
181181
pipeline.set_progress_bar_config(disable=True)
182182

183183
# run inference
184184
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
185-
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
186-
autocast_ctx = nullcontext()
185+
autocast_ctx = torch.autocast(accelerator.device.type)
187186

188-
with autocast_ctx:
189-
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
187+
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
188+
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
189+
)
190+
images = []
191+
for _ in range(args.num_validation_images):
192+
with autocast_ctx:
193+
image = pipeline(
194+
prompt_embeds=prompt_embeds,
195+
pooled_prompt_embeds=pooled_prompt_embeds,
196+
generator=generator).images[0]
197+
images.append(image)
190198

191199
for tracker in accelerator.trackers:
192200
phase_name = "test" if is_final_validation else "validation"
@@ -203,8 +211,7 @@ def log_validation(
203211
)
204212

205213
del pipeline
206-
if torch.cuda.is_available():
207-
torch.cuda.empty_cache()
214+
free_memory()
208215

209216
return images
210217

0 commit comments

Comments
 (0)