Skip to content

Commit 46a6b26

Browse files
committed
Fix fp16 bug
1 parent 6f74ef5 commit 46a6b26

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def log_validation(
203203

204204
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
205205

206-
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
206+
pipeline = pipeline.to(accelerator.device)
207207
pipeline.set_progress_bar_config(disable=True)
208208

209209
# run inference
@@ -213,7 +213,7 @@ def log_validation(
213213
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
214214
autocast_ctx = nullcontext()
215215
else:
216-
autocast_ctx = torch.autocast(accelerator.device.type)
216+
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
217217

218218
with autocast_ctx:
219219
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

0 commit comments

Comments
 (0)