Skip to content

Commit a91138d

Browse files
authored
Merge branch 'main' into improve-lora-warning-msg
2 parents da96621 + 83da817 commit a91138d

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,18 @@
2121
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
2222
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
2323
from ...schedulers import FlowMatchEulerDiscreteScheduler
24-
from ...utils import logging, replace_example_docstring
24+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
2525
from ...utils.torch_utils import randn_tensor
2626
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
2727

2828

29+
if is_torch_xla_available():
30+
import torch_xla.core.xla_model as xm
31+
32+
XLA_AVAILABLE = True
33+
else:
34+
XLA_AVAILABLE = False
35+
2936
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3037

3138

@@ -564,6 +571,9 @@ def __call__(
564571
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
565572
progress_bar.update()
566573

574+
if XLA_AVAILABLE:
575+
xm.mark_step()
576+
567577
if output_type == "latent":
568578
image = latents
569579
else:

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
USE_PEFT_BACKEND,
3232
is_bs4_available,
3333
is_ftfy_available,
34+
is_torch_xla_available,
3435
logging,
3536
replace_example_docstring,
3637
scale_lora_layers,
@@ -46,6 +47,13 @@
4647
from .pipeline_output import SanaPipelineOutput
4748

4849

50+
if is_torch_xla_available():
51+
import torch_xla.core.xla_model as xm
52+
53+
XLA_AVAILABLE = True
54+
else:
55+
XLA_AVAILABLE = False
56+
4957
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5058

5159
if is_bs4_available():
@@ -864,6 +872,9 @@ def __call__(
864872
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
865873
progress_bar.update()
866874

875+
if XLA_AVAILABLE:
876+
xm.mark_step()
877+
867878
if output_type == "latent":
868879
image = latents
869880
else:

0 commit comments

Comments
 (0)