File tree (expand/collapse) — 1 file changed: +10 −1 lines changed
src/diffusers/pipelines/aura_flow — 1 file changed: +10 −1
lines changed

Original file line number | Diff line number | Diff line change

21  21   from ...models import AuraFlowTransformer2DModel, AutoencoderKL
22  22   from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
23  23   from ...schedulers import FlowMatchEulerDiscreteScheduler
24      -from ...utils import logging, replace_example_docstring
    24  +from ...utils import is_torch_xla_available, logging, replace_example_docstring
25  25   from ...utils.torch_utils import randn_tensor
26  26   from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
27  27
    28  +if is_torch_xla_available():
    29  +    import torch_xla.core.xla_model as xm
    30  +
    31  +    XLA_AVAILABLE = True
    32  +else:
    33  +    XLA_AVAILABLE = False
28  34
29  35   logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
30  36
@@ -564,6 +570,9 @@ def __call__(
564  570           if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
565  571               progress_bar.update()
566  572
     573  +        if XLA_AVAILABLE:
     574  +            xm.mark_step()
     575  +
567  576       if output_type == "latent":
568  577           image = latents
569  578       else:
You can’t perform that action at this time.
0 commit comments