|  | 
| 22 | 22 | 
 | 
| 23 | 23 | from ...callbacks import MultiPipelineCallbacks, PipelineCallback | 
| 24 | 24 | from ...image_processor import PipelineImageInput, VaeImageProcessor | 
| 25 |  | -from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin | 
|  | 25 | +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin | 
| 26 | 26 | from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel | 
| 27 | 27 | from ...schedulers import KarrasDiffusionSchedulers | 
| 28 |  | -from ...utils import PIL_INTERPOLATION, deprecate, logging | 
|  | 28 | +from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging | 
| 29 | 29 | from ...utils.torch_utils import randn_tensor | 
| 30 | 30 | from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin | 
| 31 | 31 | from . import StableDiffusionPipelineOutput | 
| 32 | 32 | from .safety_checker import StableDiffusionSafetyChecker | 
| 33 | 33 | 
 | 
| 34 | 34 | 
 | 
|  | 35 | +if is_torch_xla_available(): | 
|  | 36 | +    import torch_xla.core.xla_model as xm | 
|  | 37 | + | 
|  | 38 | +    XLA_AVAILABLE = True | 
|  | 39 | +else: | 
|  | 40 | +    XLA_AVAILABLE = False | 
|  | 41 | + | 
| 35 | 42 | logger = logging.get_logger(__name__)  # pylint: disable=invalid-name | 
| 36 | 43 | 
 | 
| 37 | 44 | 
 | 
| @@ -79,6 +86,7 @@ class StableDiffusionInstructPix2PixPipeline( | 
| 79 | 86 |     TextualInversionLoaderMixin, | 
| 80 | 87 |     StableDiffusionLoraLoaderMixin, | 
| 81 | 88 |     IPAdapterMixin, | 
|  | 89 | +    FromSingleFileMixin, | 
| 82 | 90 | ): | 
| 83 | 91 |     r""" | 
| 84 | 92 |     Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion). | 
| @@ -457,6 +465,9 @@ def __call__( | 
| 457 | 465 |                         step_idx = i // getattr(self.scheduler, "order", 1) | 
| 458 | 466 |                         callback(step_idx, t, latents) | 
| 459 | 467 | 
 | 
|  | 468 | +                if XLA_AVAILABLE: | 
|  | 469 | +                    xm.mark_step() | 
|  | 470 | + | 
| 460 | 471 |         if not output_type == "latent": | 
| 461 | 472 |             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] | 
| 462 | 473 |             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) | 
|  | 
0 commit comments