diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 35704c574f28..2b6a9809193d 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -203,7 +203,8 @@ def log_validation( pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + #pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) #, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference @@ -213,7 +214,8 @@ def log_validation( if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: autocast_ctx = nullcontext() else: - autocast_ctx = torch.autocast(accelerator.device.type) + #autocast_ctx = torch.autocast(accelerator.device.type) + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() with autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]