@@ -161,7 +161,7 @@ def log_validation(
161161 f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
162162 f" { args .validation_prompt } ."
163163 )
164- pipeline = pipeline .to (accelerator .device , dtype = torch_dtype )
164+ pipeline = pipeline .to (accelerator .device )
165165 pipeline .set_progress_bar_config (disable = True )
166166
167167 # run inference
@@ -1580,7 +1580,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15801580 )
15811581
15821582 # handle guidance
1583- if transformer .config .guidance_embeds :
1583+ if accelerator . unwrap_model ( transformer ) .config .guidance_embeds :
15841584 guidance = torch .tensor ([args .guidance_scale ], device = accelerator .device )
15851585 guidance = guidance .expand (model_input .shape [0 ])
15861586 else :
@@ -1694,6 +1694,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16941694 # create pipeline
16951695 if not args .train_text_encoder :
16961696 text_encoder_one , text_encoder_two = load_text_encoders (text_encoder_cls_one , text_encoder_cls_two )
1697+ text_encoder_one .to (weight_dtype )
1698+ text_encoder_two .to (weight_dtype )
16971699 else : # even when training the text encoder we're only training text encoder one
16981700 text_encoder_two = text_encoder_cls_two .from_pretrained (
16991701 args .pretrained_model_name_or_path ,
0 commit comments