| 
11 | 11 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer  | 
12 | 12 | 
 
  | 
13 | 13 | from ...image_processor import VaeImageProcessor  | 
14 |  | -from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin  | 
 | 14 | +from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin  | 
15 | 15 | from ...models import AutoencoderKL, UNet2DConditionModel  | 
16 | 16 | from ...models.lora import adjust_lora_scale_text_encoder  | 
17 | 17 | from ...schedulers import KarrasDiffusionSchedulers  | 
18 |  | -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers  | 
 | 18 | +from ...utils import (  | 
 | 19 | +    USE_PEFT_BACKEND,  | 
 | 20 | +    BaseOutput,  | 
 | 21 | +    is_torch_xla_available,  | 
 | 22 | +    logging,  | 
 | 23 | +    scale_lora_layers,  | 
 | 24 | +    unscale_lora_layers,  | 
 | 25 | +)  | 
19 | 26 | from ...utils.torch_utils import randn_tensor  | 
20 | 27 | from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin  | 
21 | 28 | from ..stable_diffusion import StableDiffusionSafetyChecker  | 
22 | 29 | 
 
  | 
23 | 30 | 
 
  | 
 | 31 | +if is_torch_xla_available():  | 
 | 32 | +    import torch_xla.core.xla_model as xm  | 
 | 33 | + | 
 | 34 | +    XLA_AVAILABLE = True  | 
 | 35 | +else:  | 
 | 36 | +    XLA_AVAILABLE = False  | 
 | 37 | + | 
24 | 38 | logger = logging.get_logger(__name__)  # pylint: disable=invalid-name  | 
25 | 39 | 
 
  | 
26 | 40 | 
 
  | 
@@ -282,7 +296,11 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s  | 
282 | 296 | 
 
  | 
283 | 297 | 
 
  | 
284 | 298 | class TextToVideoZeroPipeline(  | 
285 |  | -    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin  | 
 | 299 | +    DiffusionPipeline,  | 
 | 300 | +    StableDiffusionMixin,  | 
 | 301 | +    TextualInversionLoaderMixin,  | 
 | 302 | +    StableDiffusionLoraLoaderMixin,  | 
 | 303 | +    FromSingleFileMixin,  | 
286 | 304 | ):  | 
287 | 305 |     r"""  | 
288 | 306 |     Pipeline for zero-shot text-to-video generation using Stable Diffusion.  | 
@@ -440,6 +458,10 @@ def backward_loop(  | 
440 | 458 |                     if callback is not None and i % callback_steps == 0:  | 
441 | 459 |                         step_idx = i // getattr(self.scheduler, "order", 1)  | 
442 | 460 |                         callback(step_idx, t, latents)  | 
 | 461 | + | 
 | 462 | +                if XLA_AVAILABLE:  | 
 | 463 | +                    xm.mark_step()  | 
 | 464 | + | 
443 | 465 |         return latents.clone().detach()  | 
444 | 466 | 
 
  | 
445 | 467 |     # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs  | 
 | 
0 commit comments