Skip to content

Commit 20cb3a8

Browse files
committed
enable dreambooth_lora on other devices
Signed-off-by: jiqing-feng <[email protected]>
1 parent aeac0a0 commit 20cb3a8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,14 @@ def log_validation(
151151
if args.validation_images is None:
152152
images = []
153153
for _ in range(args.num_validation_images):
154-
with torch.cuda.amp.autocast():
154+
with torch.amp.autocast(pipeline.device.type):
155155
image = pipeline(**pipeline_args, generator=generator).images[0]
156156
images.append(image)
157157
else:
158158
images = []
159159
for image in args.validation_images:
160160
image = Image.open(image)
161-
with torch.cuda.amp.autocast():
161+
with torch.amp.autocast(pipeline.device.type):
162162
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
163163
images.append(image)
164164

0 commit comments

Comments
 (0)