Skip to content

Commit 2d43094

Browse files
fix RuntimeError: Input type (float) and bias type (c10::Half) should be the same in train_text_to_image_lora.py (#6259)
* fix RuntimeError: Input type (float) and bias type (c10::Half) should be the same * format source code * format code * remove the autocast blocks within the pipeline * add autocast blocks to pipeline caller in train_text_to_image_lora.py
1 parent 7c05b97 commit 2d43094

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -847,10 +847,11 @@ def collate_fn(examples):
847847
if args.seed is not None:
848848
generator = generator.manual_seed(args.seed)
849849
images = []
850-
for _ in range(args.num_validation_images):
851-
images.append(
852-
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
853-
)
850+
with torch.cuda.amp.autocast():
851+
for _ in range(args.num_validation_images):
852+
images.append(
853+
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
854+
)
854855

855856
for tracker in accelerator.trackers:
856857
if tracker.name == "tensorboard":
@@ -916,8 +917,11 @@ def collate_fn(examples):
916917
if args.seed is not None:
917918
generator = generator.manual_seed(args.seed)
918919
images = []
919-
for _ in range(args.num_validation_images):
920-
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
920+
with torch.cuda.amp.autocast():
921+
for _ in range(args.num_validation_images):
922+
images.append(
923+
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
924+
)
921925

922926
for tracker in accelerator.trackers:
923927
if len(images) != 0:

0 commit comments

Comments
 (0)