File tree Expand file tree Collapse file tree 2 files changed +22
-1
lines changed Expand file tree Collapse file tree 2 files changed +22
-1
lines changed Original file line number Diff line number Diff line change 2121from ...models import AuraFlowTransformer2DModel , AutoencoderKL
2222from ...models .attention_processor import AttnProcessor2_0 , FusedAttnProcessor2_0 , XFormersAttnProcessor
2323from ...schedulers import FlowMatchEulerDiscreteScheduler
24- from ...utils import logging , replace_example_docstring
24+ from ...utils import is_torch_xla_available , logging , replace_example_docstring
2525from ...utils .torch_utils import randn_tensor
2626from ..pipeline_utils import DiffusionPipeline , ImagePipelineOutput
2727
2828
29+ if is_torch_xla_available ():
30+ import torch_xla .core .xla_model as xm
31+
32+ XLA_AVAILABLE = True
33+ else :
34+ XLA_AVAILABLE = False
35+
2936logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
3037
3138
@@ -564,6 +571,9 @@ def __call__(
564571 if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
565572 progress_bar .update ()
566573
574+ if XLA_AVAILABLE :
575+ xm .mark_step ()
576+
567577 if output_type == "latent" :
568578 image = latents
569579 else :
Original file line number Diff line number Diff line change 3131 USE_PEFT_BACKEND ,
3232 is_bs4_available ,
3333 is_ftfy_available ,
34+ is_torch_xla_available ,
3435 logging ,
3536 replace_example_docstring ,
3637 scale_lora_layers ,
4647from .pipeline_output import SanaPipelineOutput
4748
4849
50+ if is_torch_xla_available ():
51+ import torch_xla .core .xla_model as xm
52+
53+ XLA_AVAILABLE = True
54+ else :
55+ XLA_AVAILABLE = False
56+
4957logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
5058
5159if is_bs4_available ():
@@ -864,6 +872,9 @@ def __call__(
864872 if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
865873 progress_bar .update ()
866874
875+ if XLA_AVAILABLE :
876+ xm .mark_step ()
877+
867878 if output_type == "latent" :
868879 image = latents
869880 else :
You can’t perform that action at this time.
0 commit comments