|
22 | 22 |
|
23 | 23 | from ...callbacks import MultiPipelineCallbacks, PipelineCallback |
24 | 24 | from ...image_processor import PipelineImageInput, VaeImageProcessor |
25 | | -from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin |
| 25 | +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin |
26 | 26 | from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel |
27 | 27 | from ...schedulers import KarrasDiffusionSchedulers |
28 | | -from ...utils import PIL_INTERPOLATION, deprecate, logging |
| 28 | +from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging |
29 | 29 | from ...utils.torch_utils import randn_tensor |
30 | 30 | from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin |
31 | 31 | from . import StableDiffusionPipelineOutput |
32 | 32 | from .safety_checker import StableDiffusionSafetyChecker |
33 | 33 |
|
34 | 34 |
|
| 35 | +if is_torch_xla_available(): |
| 36 | + import torch_xla.core.xla_model as xm |
| 37 | + |
| 38 | + XLA_AVAILABLE = True |
| 39 | +else: |
| 40 | + XLA_AVAILABLE = False |
| 41 | + |
35 | 42 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
36 | 43 |
|
37 | 44 |
|
@@ -79,6 +86,7 @@ class StableDiffusionInstructPix2PixPipeline( |
79 | 86 | TextualInversionLoaderMixin, |
80 | 87 | StableDiffusionLoraLoaderMixin, |
81 | 88 | IPAdapterMixin, |
| 89 | + FromSingleFileMixin, |
82 | 90 | ): |
83 | 91 | r""" |
84 | 92 | Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion). |
@@ -457,6 +465,9 @@ def __call__( |
457 | 465 | step_idx = i // getattr(self.scheduler, "order", 1) |
458 | 466 | callback(step_idx, t, latents) |
459 | 467 |
|
| 468 | + if XLA_AVAILABLE: |
| 469 | + xm.mark_step() |
| 470 | + |
460 | 471 | if not output_type == "latent": |
461 | 472 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] |
462 | 473 | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) |
|
0 commit comments