Commit 4f49d09

Add torch_xla support to pipeline_aura_flow.py
1 parent 4b55713 commit 4f49d09

1 file changed: +10 -1 lines changed


src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 10 additions & 1 deletion
@@ -21,10 +21,16 @@
 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
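Note on the hunk above: it follows the guarded-import idiom used across diffusers. torch_xla is an optional dependency, so it is only imported when is_torch_xla_available() reports that it can be loaded, and the result is cached in the module-level XLA_AVAILABLE flag so the per-step check inside the denoising loop is a plain boolean test rather than a repeated import probe. A minimal sketch of calling the same helper from user code (illustrative only, not part of this commit):

# Illustrative only: confirm whether diffusers will take the torch_xla
# code path in this environment.
from diffusers.utils import is_torch_xla_available

print("torch_xla available:", is_torch_xla_available())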
@@ -564,6 +570,9 @@ def __call__(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if output_type == "latent":
             image = latents
         else:
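Note on the hunk above: it inserts xm.mark_step() at the end of each denoising iteration. Under torch_xla's lazy-tensor execution model, mark_step() is the barrier that tells XLA to compile and dispatch the graph traced so far, so each scheduler step is materialized as it completes instead of accumulating into one large graph that only runs when the final tensors are read back. A usage sketch, assuming a TPU/XLA environment with torch_xla installed; the checkpoint name "fal/AuraFlow", the prompt, and the dtype are illustrative choices, not part of the commit:

import torch
import torch_xla.core.xla_model as xm

from diffusers import AuraFlowPipeline

# Illustrative settings; any AuraFlow checkpoint and prompt would do.
device = xm.xla_device()  # e.g. a TPU core
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.bfloat16)
pipe.to(device)

# With this change, the pipeline calls xm.mark_step() after every scheduler
# step, so each denoising iteration is compiled and executed on the XLA
# device as the loop runs.
image = pipe("a watercolor painting of a lighthouse at dawn", num_inference_steps=28).images[0]
image.save("aura_flow_xla.png")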
