|
@@ -11,16 +11,30 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    BaseOutput,
+    is_torch_xla_available,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionSafetyChecker


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -282,7 +296,11 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s


 class TextToVideoZeroPipeline(
-    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+    DiffusionPipeline,
+    StableDiffusionMixin,
+    TextualInversionLoaderMixin,
+    StableDiffusionLoraLoaderMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for zero-shot text-to-video generation using Stable Diffusion.
@@ -440,6 +458,10 @@ def backward_loop(
                     if callback is not None and i % callback_steps == 0:
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         return latents.clone().detach()

     # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
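With FromSingleFileMixin added to the base classes, the pipeline can also be instantiated from a single Stable Diffusion checkpoint file via `from_single_file`. A minimal sketch, assuming a local .safetensors checkpoint path (the path and dtype are illustrative, not part of this diff):

import torch
from diffusers import TextToVideoZeroPipeline

# Hypothetical checkpoint path; any compatible SD 1.x single-file checkpoint should work.
pipe = TextToVideoZeroPipeline.from_single_file(
    "./stable_diffusion_v1_5.safetensors", torch_dtype=torch.float16
).to("cuda")

result = pipe(prompt="A panda is playing guitar on times square")
frames = result.images  # list of generated video frames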
|
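The `xm.mark_step()` call added at the end of each denoising iteration cuts the lazily traced XLA graph, so each step is compiled and dispatched to the TPU instead of accumulating into one large graph. A hedged sketch of running the pipeline on an XLA device (the model id, dtype, and device handling are assumptions, not part of this diff):

import torch
import torch_xla.core.xla_model as xm
from diffusers import TextToVideoZeroPipeline

device = xm.xla_device()  # the attached TPU core
pipe = TextToVideoZeroPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.bfloat16
).to(device)

# backward_loop now calls xm.mark_step() after every scheduler step,
# so each denoising step executes on the device as it is traced.
result = pipe(prompt="A panda is playing guitar on times square")
frames = result.images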