Commit c5b4a3c

test flux
1 parent: 14f6464

File tree

2 files changed: +51 −40 lines

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 11 additions & 1 deletion
@@ -31,11 +31,17 @@
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_xla_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -498,6 +504,8 @@ def custom_forward(*inputs):
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
+                if XLA_AVAILABLE:
+                    xm.mark_step()
 
             # controlnet residual
             if controlnet_block_samples is not None:
@@ -534,6 +542,8 @@ def custom_forward(*inputs):
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
+                if XLA_AVAILABLE:
+                    xm.mark_step()
 
             # controlnet residual
             if controlnet_single_block_samples is not None:
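
The two hunks above apply the same pattern to the dual-stream and single-stream block loops: when torch_xla is available, xm.mark_step() is called after each transformer block, so the lazy-tensor graph is materialized per block instead of accumulating across the whole forward pass. Below is a minimal, self-contained sketch of that guard-and-step pattern; toy_block and the tensor shapes are illustrative stand-ins, not part of diffusers.

# Standalone sketch of the XLA guard used in the diff. Falls back to CPU
# when torch_xla is not installed; toy_block is a hypothetical stand-in
# for one Flux transformer block.
import torch

try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def toy_block(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for one transformer block's computation.
    return torch.nn.functional.gelu(x) + x


device = xm.xla_device() if XLA_AVAILABLE else torch.device("cpu")
x = torch.randn(1, 16, 64, device=device)
for _ in range(4):
    x = toy_block(x)
    if XLA_AVAILABLE:
        # Cut the pending lazy graph here, as the diff does after each block.
        xm.mark_step()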

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 40 additions & 39 deletions
@@ -708,47 +708,48 @@ def __call__(
             guidance = None
 
         # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                if self.interrupt:
-                    continue
-
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents_dtype = latents.dtype
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+        #with self.progress_bar(total=num_inference_steps) as progress_bar:
+        for i, t in enumerate(timesteps):
+            if self.interrupt:
+                continue
+
+            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+            timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+            noise_pred = self.transformer(
+                hidden_states=latents,
+                timestep=timestep / 1000,
+                guidance=guidance,
+                pooled_projections=pooled_prompt_embeds,
+                encoder_hidden_states=prompt_embeds,
+                txt_ids=text_ids,
+                img_ids=latent_image_ids,
+                joint_attention_kwargs=self.joint_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents_dtype = latents.dtype
+
+            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+            if latents.dtype != latents_dtype:
+                if torch.backends.mps.is_available():
+                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                    latents = latents.to(latents_dtype)
+
+            if callback_on_step_end is not None:
+                callback_kwargs = {}
+                for k in callback_on_step_end_tensor_inputs:
+                    callback_kwargs[k] = locals()[k]
+                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                latents = callback_outputs.pop("latents", latents)
+                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
 
             # call the callback, if provided
-            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                progress_bar.update()
+            # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+            #     progress_bar.update()
 
             if XLA_AVAILABLE:
                 xm.mark_step()
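
With the progress bar commented out by this test commit and the per-block mark_step() calls added in transformer_flux.py, the modified pipeline can be exercised on an XLA device roughly as follows. This is a hedged sketch, not part of the commit: the model id, dtype, prompt, and step count are illustrative choices.

# Usage sketch: running FluxPipeline on an XLA device (e.g. a TPU core).
# Assumes torch_xla is installed; all settings here are illustrative.
import torch
import torch_xla.core.xla_model as xm
from diffusers import FluxPipeline

device = xm.xla_device()
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe.to(device)

# FLUX.1-schnell is distilled for few steps and does not use guidance.
image = pipe(
    "a watercolor fox in a snowy forest",
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
image.save("flux_xla_test.png")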
