Commit 96317ee

testing flux
1 parent a26d570 commit 96317ee

1 file changed: +53 -53 lines changed


src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 53 additions & 53 deletions
@@ -874,68 +874,68 @@ def __call__(
             )
 
         # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                if self.interrupt:
-                    continue
-
-                if image_embeds is not None:
-                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-                noise_pred = self.transformer(
+        #with self.progress_bar(total=num_inference_steps) as progress_bar:
+        for i, t in enumerate(timesteps):
+            if self.interrupt:
+                continue
+
+            if image_embeds is not None:
+                self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+            timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+            noise_pred = self.transformer(
+                hidden_states=latents,
+                timestep=timestep / 1000,
+                guidance=guidance,
+                pooled_projections=pooled_prompt_embeds,
+                encoder_hidden_states=prompt_embeds,
+                txt_ids=text_ids,
+                img_ids=latent_image_ids,
+                joint_attention_kwargs=self.joint_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+            if do_true_cfg:
+                if negative_image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                neg_noise_pred = self.transformer(
                     hidden_states=latents,
                     timestep=timestep / 1000,
                     guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
+                    pooled_projections=negative_pooled_prompt_embeds,
+                    encoder_hidden_states=negative_prompt_embeds,
                     txt_ids=text_ids,
                     img_ids=latent_image_ids,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
+                noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents_dtype = latents.dtype
+            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+            if latents.dtype != latents_dtype:
+                if torch.backends.mps.is_available():
+                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                    latents = latents.to(latents_dtype)
+
+            if callback_on_step_end is not None:
+                callback_kwargs = {}
+                for k in callback_on_step_end_tensor_inputs:
+                    callback_kwargs[k] = locals()[k]
+                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                latents = callback_outputs.pop("latents", latents)
+                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+            # call the callback, if provided
+            # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+            # progress_bar.update()
 
-                if do_true_cfg:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    neg_noise_pred = self.transformer(
-                        hidden_states=latents,
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                    )[0]
-                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents_dtype = latents.dtype
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-
-                if XLA_AVAILABLE:
-                    xm.mark_step()
+            if XLA_AVAILABLE:
+                xm.mark_step()
 
         if output_type == "latent":
            image = latents
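The hunk's net change is narrow: the with self.progress_bar(...) context manager is commented out, the loop body is dedented one level, and the per-step progress_bar.update() call is commented out; the transformer forward passes, true-CFG combination, scheduler step, MPS dtype guard, callbacks, and XLA stepping are carried over unchanged. A minimal, self-contained sketch (hypothetical, not part of this commit) of keeping per-step progress reporting without the pipeline helper, by wrapping the timestep iterator in tqdm:

# Hypothetical sketch, not from this commit: tqdm wraps the timestep
# iterator directly, standing in for the commented-out self.progress_bar.
import torch
from tqdm import tqdm

num_inference_steps = 4
timesteps = torch.linspace(1000.0, 1.0, num_inference_steps)  # stand-in schedule

for i, t in enumerate(tqdm(timesteps, desc="denoising")):
    pass  # one transformer forward + scheduler.step per iteration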

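The true-CFG line carried over in the hunk, noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred), is a linear extrapolation away from the negative-prompt prediction. A toy worked example (stand-in tensors, not pipeline code):

# Toy illustration of the true-CFG combination from the diff.
import torch

pos = torch.tensor([1.0, 2.0])  # stand-in for noise_pred (positive prompt)
neg = torch.tensor([0.5, 0.5])  # stand-in for neg_noise_pred (negative prompt)

for true_cfg_scale in (1.0, 3.5):
    combined = neg + true_cfg_scale * (pos - neg)
    print(true_cfg_scale, combined.tolist())
# 1.0 -> [1.0, 2.0]   (reduces to the positive prediction)
# 3.5 -> [2.25, 5.75] (pushes further from the negative prediction)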
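The dtype guard after scheduler.step also survives the move: per the diff's own comment, some platforms (e.g. Apple MPS, see https://github.com/pytorch/pytorch/pull/99272) can return latents upcast to a wider dtype, so the loop restores the original one. A self-contained sketch of the same pattern, minus the MPS check:

# Sketch of the dtype-restore pattern used after scheduler.step in the diff.
import torch

latents = torch.zeros(2, dtype=torch.float16)
latents_dtype = latents.dtype

stepped = latents.float() + 1.0  # stand-in for a scheduler step that upcasts

if stepped.dtype != latents_dtype:
    stepped = stepped.to(latents_dtype)  # restore the original dtype

print(stepped.dtype)  # torch.float16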