@@ -874,68 +874,68 @@ def __call__(
874874 )
875875
876876 # 6. Denoising loop
877- with self .progress_bar (total = num_inference_steps ) as progress_bar :
878- for i , t in enumerate (timesteps ):
879- if self .interrupt :
880- continue
881-
882- if image_embeds is not None :
883- self ._joint_attention_kwargs ["ip_adapter_image_embeds" ] = image_embeds
884- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
885- timestep = t .expand (latents .shape [0 ]).to (latents .dtype )
886-
887- noise_pred = self .transformer (
877+ #with self.progress_bar(total=num_inference_steps) as progress_bar:
878+ for i , t in enumerate (timesteps ):
879+ if self .interrupt :
880+ continue
881+
882+ if image_embeds is not None :
883+ self ._joint_attention_kwargs ["ip_adapter_image_embeds" ] = image_embeds
884+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
885+ timestep = t .expand (latents .shape [0 ]).to (latents .dtype )
886+
887+ noise_pred = self .transformer (
888+ hidden_states = latents ,
889+ timestep = timestep / 1000 ,
890+ guidance = guidance ,
891+ pooled_projections = pooled_prompt_embeds ,
892+ encoder_hidden_states = prompt_embeds ,
893+ txt_ids = text_ids ,
894+ img_ids = latent_image_ids ,
895+ joint_attention_kwargs = self .joint_attention_kwargs ,
896+ return_dict = False ,
897+ )[0 ]
898+
899+ if do_true_cfg :
900+ if negative_image_embeds is not None :
901+ self ._joint_attention_kwargs ["ip_adapter_image_embeds" ] = negative_image_embeds
902+ neg_noise_pred = self .transformer (
888903 hidden_states = latents ,
889904 timestep = timestep / 1000 ,
890905 guidance = guidance ,
891- pooled_projections = pooled_prompt_embeds ,
892- encoder_hidden_states = prompt_embeds ,
906+ pooled_projections = negative_pooled_prompt_embeds ,
907+ encoder_hidden_states = negative_prompt_embeds ,
893908 txt_ids = text_ids ,
894909 img_ids = latent_image_ids ,
895910 joint_attention_kwargs = self .joint_attention_kwargs ,
896911 return_dict = False ,
897912 )[0 ]
913+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred )
914+
915+ # compute the previous noisy sample x_t -> x_t-1
916+ latents_dtype = latents .dtype
917+ latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
918+
919+ if latents .dtype != latents_dtype :
920+ if torch .backends .mps .is_available ():
921+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
922+ latents = latents .to (latents_dtype )
923+
924+ if callback_on_step_end is not None :
925+ callback_kwargs = {}
926+ for k in callback_on_step_end_tensor_inputs :
927+ callback_kwargs [k ] = locals ()[k ]
928+ callback_outputs = callback_on_step_end (self , i , t , callback_kwargs )
929+
930+ latents = callback_outputs .pop ("latents" , latents )
931+ prompt_embeds = callback_outputs .pop ("prompt_embeds" , prompt_embeds )
932+
933+ # call the callback, if provided
934+ # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
935+ # progress_bar.update()
898936
899- if do_true_cfg :
900- if negative_image_embeds is not None :
901- self ._joint_attention_kwargs ["ip_adapter_image_embeds" ] = negative_image_embeds
902- neg_noise_pred = self .transformer (
903- hidden_states = latents ,
904- timestep = timestep / 1000 ,
905- guidance = guidance ,
906- pooled_projections = negative_pooled_prompt_embeds ,
907- encoder_hidden_states = negative_prompt_embeds ,
908- txt_ids = text_ids ,
909- img_ids = latent_image_ids ,
910- joint_attention_kwargs = self .joint_attention_kwargs ,
911- return_dict = False ,
912- )[0 ]
913- noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred )
914-
915- # compute the previous noisy sample x_t -> x_t-1
916- latents_dtype = latents .dtype
917- latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
918-
919- if latents .dtype != latents_dtype :
920- if torch .backends .mps .is_available ():
921- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
922- latents = latents .to (latents_dtype )
923-
924- if callback_on_step_end is not None :
925- callback_kwargs = {}
926- for k in callback_on_step_end_tensor_inputs :
927- callback_kwargs [k ] = locals ()[k ]
928- callback_outputs = callback_on_step_end (self , i , t , callback_kwargs )
929-
930- latents = callback_outputs .pop ("latents" , latents )
931- prompt_embeds = callback_outputs .pop ("prompt_embeds" , prompt_embeds )
932-
933- # call the callback, if provided
934- if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
935- progress_bar .update ()
936-
937- if XLA_AVAILABLE :
938- xm .mark_step ()
937+ if XLA_AVAILABLE :
938+ xm .mark_step ()
939939
940940 if output_type == "latent" :
941941 image = latents
0 commit comments