File tree Expand file tree Collapse file tree 1 file changed +11
-1
lines changed
src/diffusers/pipelines/aura_flow Expand file tree Collapse file tree 1 file changed +11
-1
lines changed Original file line number Diff line number Diff line change 21
21
from ...models import AuraFlowTransformer2DModel , AutoencoderKL
22
22
from ...models .attention_processor import AttnProcessor2_0 , FusedAttnProcessor2_0 , XFormersAttnProcessor
23
23
from ...schedulers import FlowMatchEulerDiscreteScheduler
24
- from ...utils import logging , replace_example_docstring
24
+ from ...utils import is_torch_xla_available , logging , replace_example_docstring
25
25
from ...utils .torch_utils import randn_tensor
26
26
from ..pipeline_utils import DiffusionPipeline , ImagePipelineOutput
27
27
28
28
29
+ if is_torch_xla_available ():
30
+ import torch_xla .core .xla_model as xm
31
+
32
+ XLA_AVAILABLE = True
33
+ else :
34
+ XLA_AVAILABLE = False
35
+
29
36
logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
30
37
31
38
@@ -564,6 +571,9 @@ def __call__(
564
571
if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
565
572
progress_bar .update ()
566
573
574
+ if XLA_AVAILABLE :
575
+ xm .mark_step ()
576
+
567
577
if output_type == "latent" :
568
578
image = latents
569
579
else :
You can’t perform that action at this time.
0 commit comments