File tree Expand file tree Collapse file tree 1 file changed +10
-0
lines changed
src/diffusers/pipelines/stable_audio Expand file tree Collapse file tree 1 file changed +10
-0
lines changed Original file line number Diff line number Diff line change 2626from ...models .embeddings import get_1d_rotary_pos_embed
2727from ...schedulers import EDMDPMSolverMultistepScheduler
2828from ...utils import (
29+ is_torch_xla_available ,
2930 logging ,
3031 replace_example_docstring ,
3132)
3233from ...utils .torch_utils import randn_tensor
3334from ..pipeline_utils import AudioPipelineOutput , DiffusionPipeline
3435from .modeling_stable_audio import StableAudioProjectionModel
3536
37+ if is_torch_xla_available ():
38+ import torch_xla .core .xla_model as xm
39+
40+ XLA_AVAILABLE = True
41+ else :
42+ XLA_AVAILABLE = False
3643
3744logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
3845
@@ -725,6 +732,9 @@ def __call__(
725732 if callback is not None and i % callback_steps == 0 :
726733 step_idx = i // getattr (self .scheduler , "order" , 1 )
727734 callback (step_idx , t , latents )
735+
736+ if XLA_AVAILABLE :
737+ xm .mark_step ()
728738
729739 # 9. Post-processing
730740 if not output_type == "latent" :
You can’t perform that action at this time.
0 commit comments