From f9ea86b9659994bfdafbea205757cc0f6c8d686b Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Mon, 23 Dec 2024 22:47:10 +0530 Subject: [PATCH] [Add] torch_xla support in pipeline_sana.py --- src/diffusers/pipelines/sana/pipeline_sana.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index fe3c9e13aa31..c90dec4d41b3 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -31,6 +31,7 @@ USE_PEFT_BACKEND, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -46,6 +47,13 @@ from .pipeline_output import SanaPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): @@ -864,6 +872,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if output_type == "latent": image = latents else: