Skip to content

Commit 9fe3ce9

Browse files
fancy45daddyhlky
authored andcommitted
Update pipeline_controlnet.py add support for pytorch_xla (#10222)
* Update pipeline_controlnet.py * make style --------- Co-authored-by: hlky <[email protected]>
1 parent 83ac101 commit 9fe3ce9

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ...utils import (
3232
USE_PEFT_BACKEND,
3333
deprecate,
34+
is_torch_xla_available,
3435
logging,
3536
replace_example_docstring,
3637
scale_lora_layers,
@@ -42,6 +43,13 @@
4243
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
4344

4445

46+
if is_torch_xla_available():
47+
import torch_xla.core.xla_model as xm
48+
49+
XLA_AVAILABLE = True
50+
else:
51+
XLA_AVAILABLE = False
52+
4553
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4654

4755

@@ -1323,6 +1331,8 @@ def __call__(
13231331
step_idx = i // getattr(self.scheduler, "order", 1)
13241332
callback(step_idx, t, latents)
13251333

1334+
if XLA_AVAILABLE:
1335+
xm.mark_step()
13261336
# If we do sequential model offloading, let's offload unet and controlnet
13271337
# manually for max memory savings
13281338
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:

0 commit comments

Comments
 (0)