@@ -1080,7 +1080,7 @@ def __call__(
10801080                        tile_latents  =  latents [:, :, px_row_init :px_row_end , px_col_init :px_col_end ]
10811081                        # expand the latents if we are doing classifier free guidance 
10821082                        latent_model_input  =  (
1083-                             torch .cat ([tile_latents ] *  2 ) if  self .do_classifier_free_guidance  else  latents 
1083+                             torch .cat ([tile_latents ] *  2 ) if  self .do_classifier_free_guidance  else  tile_latents 
10841084                        )
10851085                        latent_model_input  =  self .scheduler .scale_model_input (latent_model_input , t )
10861086
@@ -1089,15 +1089,15 @@ def __call__(
10891089                            "text_embeds" : embeddings_and_added_time [row ][col ][1 ],
10901090                            "time_ids" : embeddings_and_added_time [row ][col ][2 ],
10911091                        }
1092-                         with  torch .amp .autocast (device .type , dtype = dtype , enabled = dtype  !=  self .unet .dtype ):
1093-                              noise_pred  =  self .unet (
1094-                                  latent_model_input ,
1095-                                  t ,
1096-                                  encoder_hidden_states = embeddings_and_added_time [row ][col ][0 ],
1097-                                  cross_attention_kwargs = self .cross_attention_kwargs ,
1098-                                  added_cond_kwargs = added_cond_kwargs ,
1099-                                  return_dict = False ,
1100-                              )[0 ]
1092+                         # with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):
1093+                         noise_pred  =  self .unet (
1094+                             latent_model_input ,
1095+                             t ,
1096+                             encoder_hidden_states = embeddings_and_added_time [row ][col ][0 ],
1097+                             cross_attention_kwargs = self .cross_attention_kwargs ,
1098+                             added_cond_kwargs = added_cond_kwargs ,
1099+                             return_dict = False ,
1100+                         )[0 ]
11011101
11021102                        # perform guidance 
11031103                        if  self .do_classifier_free_guidance :
0 commit comments