We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent aeac0a0 commit 20cb3a8Copy full SHA for 20cb3a8
examples/dreambooth/train_dreambooth_lora.py
@@ -151,14 +151,14 @@ def log_validation(
151
if args.validation_images is None:
152
images = []
153
for _ in range(args.num_validation_images):
154
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast(pipeline.device.type):
155
image = pipeline(**pipeline_args, generator=generator).images[0]
156
images.append(image)
157
else:
158
159
for image in args.validation_images:
160
image = Image.open(image)
161
162
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
163
164
0 commit comments