Skip to content

Commit eee1e5f

Browse files
authored
Merge branch 'main' into instruct-pix2pix-xla-single-file
2 parents d98623c + a17832b commit eee1e5f

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

examples/community/rerender_a_video.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,17 @@
3030
from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
3131
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
3232
from diffusers.schedulers import KarrasDiffusionSchedulers
33-
from diffusers.utils import BaseOutput, deprecate, logging
33+
from diffusers.utils import BaseOutput, deprecate, is_torch_xla_available, logging
3434
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
3535

3636

37+
if is_torch_xla_available():
38+
import torch_xla.core.xla_model as xm
39+
40+
XLA_AVAILABLE = True
41+
else:
42+
XLA_AVAILABLE = False
43+
3744
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3845

3946

@@ -1100,6 +1107,9 @@ def denoising_loop(latents, mask=None, xtrg=None, noise_rescale=None):
11001107
if callback is not None and i % callback_steps == 0:
11011108
callback(i, t, latents)
11021109

1110+
if XLA_AVAILABLE:
1111+
xm.mark_step()
1112+
11031113
return latents
11041114

11051115
if mask_start_t <= mask_end_t:

0 commit comments

Comments
 (0)