Commit 9627cc1

add tqdm again to flux pipeline.
1 parent de5faf8 commit 9627cc1


src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 40 additions & 40 deletions
@@ -708,48 +708,48 @@ def __call__(
             guidance = None
 
         # 6. Denoising loop
-        #with self.progress_bar(total=num_inference_steps) as progress_bar:
-        for i, t in enumerate(timesteps):
-            if self.interrupt:
-                continue
-
-            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-            timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-            noise_pred = self.transformer(
-                hidden_states=latents,
-                timestep=timestep / 1000,
-                guidance=guidance,
-                pooled_projections=pooled_prompt_embeds,
-                encoder_hidden_states=prompt_embeds,
-                txt_ids=text_ids,
-                img_ids=latent_image_ids,
-                joint_attention_kwargs=self.joint_attention_kwargs,
-                return_dict=False,
-            )[0]
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents_dtype = latents.dtype
-
-            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-            if latents.dtype != latents_dtype:
-                if torch.backends.mps.is_available():
-                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                    latents = latents.to(latents_dtype)
-
-            if callback_on_step_end is not None:
-                callback_kwargs = {}
-                for k in callback_on_step_end_tensor_inputs:
-                    callback_kwargs[k] = locals()[k]
-                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                latents = callback_outputs.pop("latents", latents)
-                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+                noise_pred = self.transformer(
+                    hidden_states=latents,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
 
                 # call the callback, if provided
-            # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-            #     progress_bar.update()
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
 
                 if XLA_AVAILABLE:
                     xm.mark_step()
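
With this change the denoising loop runs inside self.progress_bar(total=num_inference_steps) again and progress_bar.update() is called once per completed step, so FluxPipeline reports denoising progress through tqdm as before. Below is a minimal usage sketch; the checkpoint, device, prompt, and step count are illustrative and not part of this commit:

import torch
from diffusers import FluxPipeline

# Illustrative checkpoint and device; any Flux checkpoint behaves the same way.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# With the progress bar restored, this call shows a tqdm bar that advances
# once per denoising step (28 updates here).
image = pipe(
    "a corgi astronaut floating above the earth",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_corgi.png")

# The bar can still be configured or silenced via the standard pipeline hook.
pipe.set_progress_bar_config(disable=True)

self.progress_bar here is the tqdm-backed helper inherited from DiffusionPipeline, which is why the commit message describes the change as adding tqdm back to the Flux pipeline.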
