@@ -708,48 +708,48 @@ def __call__(
         guidance = None

         # 6. Denoising loop
-        # with self.progress_bar(total=num_inference_steps) as progress_bar:
-        for i, t in enumerate(timesteps):
-            if self.interrupt:
-                continue
-
-            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-            timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-            noise_pred = self.transformer(
-                hidden_states=latents,
-                timestep=timestep / 1000,
-                guidance=guidance,
-                pooled_projections=pooled_prompt_embeds,
-                encoder_hidden_states=prompt_embeds,
-                txt_ids=text_ids,
-                img_ids=latent_image_ids,
-                joint_attention_kwargs=self.joint_attention_kwargs,
-                return_dict=False,
-            )[0]
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents_dtype = latents.dtype
-
-            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-            if latents.dtype != latents_dtype:
-                if torch.backends.mps.is_available():
-                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                    latents = latents.to(latents_dtype)
-
-            if callback_on_step_end is not None:
-                callback_kwargs = {}
-                for k in callback_on_step_end_tensor_inputs:
-                    callback_kwargs[k] = locals()[k]
-                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                latents = callback_outputs.pop("latents", latents)
-                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+                noise_pred = self.transformer(
+                    hidden_states=latents,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                 # call the callback, if provided
-            # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-            # progress_bar.update()
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()

                 if XLA_AVAILABLE:
                     xm.mark_step()
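As a sanity check on the gating restored above, here is a small standalone sketch (plain Python; the helper name is ours, not part of the pipeline) that counts how often `progress_bar.update()` would fire for a given step count, warmup count, and scheduler order:

def progress_updates(num_steps, num_warmup_steps=0, order=1):
    # Mirrors the restored condition: tick on the final step, or on every
    # `order`-aligned step once the warmup steps are past.
    timesteps = range(num_steps)
    ticks = 0
    for i, _t in enumerate(timesteps):
        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % order == 0):
            ticks += 1
    return ticks

# With a first-order scheduler and no warmup (the common FLUX case),
# the bar advances exactly once per denoising step.
assert progress_updates(28) == 28
assert progress_updates(28, order=2) == 14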
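For context, the loop above hands `(self, i, t, callback_kwargs)` to `callback_on_step_end` and pops the returned entries back into the step, so a callback can observe or rewrite the tensors named in `callback_on_step_end_tensor_inputs` ("latents" by default in diffusers). A minimal sketch of a compatible callback, assuming a stock diffusers `FluxPipeline`; the checkpoint id and prompt are illustrative:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

def on_step_end(pipeline, step_index, timestep, callback_kwargs):
    # `callback_kwargs` holds the requested tensors; whatever this dict
    # returns is popped back into the loop, so returning it unchanged
    # makes the callback a pure observer.
    latents = callback_kwargs["latents"]
    print(f"step {step_index}: t={timestep}, latents {tuple(latents.shape)}")
    return callback_kwargs

image = pipe(
    "a photo of a red fox in the snow",
    num_inference_steps=4,
    callback_on_step_end=on_step_end,
).images[0]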