
Commit d98623c

Add torch_xla and from_single_file to instruct-pix2pix
1 parent c28db0a commit d98623c

File tree: 2 files changed (+21 -2 lines)

    src/diffusers/loaders/single_file_utils.py
    src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
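What the commit enables, as a rough sketch: StableDiffusionInstructPix2PixPipeline can now be loaded directly from an original single-file checkpoint via from_single_file(). The checkpoint filename below is an assumption, not something stated in this commit; substitute any InstructPix2Pix .ckpt or .safetensors file.

    import torch
    from diffusers import StableDiffusionInstructPix2PixPipeline

    # from_single_file() comes from the newly added FromSingleFileMixin.
    # The checkpoint URL is an assumed example; any InstructPix2Pix
    # single-file checkpoint should be picked up by the new key/shape check.
    pipe = StableDiffusionInstructPix2PixPipeline.from_single_file(
        "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors",
        torch_dtype=torch.float16,
    ).to("cuda")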

src/diffusers/loaders/single_file_utils.py

Lines changed: 8 additions & 0 deletions
@@ -109,6 +109,7 @@
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
+    "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {

@@ -165,6 +166,7 @@
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
+    "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
 }

 # Use to configure model sample size when original config is provided
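A brief note on how the two tables above fit together: CHECKPOINT_KEY_NAMES supplies a sentinel key used to recognize a checkpoint format, and DIFFUSERS_DEFAULT_PIPELINE_PATHS maps the inferred type to the Hub repo whose configuration is used during single-file conversion. The sketch below is illustrative rather than the library's exact internal call sequence, and the local checkpoint path is a placeholder.

    from safetensors.torch import load_file
    from diffusers.loaders.single_file_utils import (
        DIFFUSERS_DEFAULT_PIPELINE_PATHS,
        infer_diffusers_model_type,
    )

    # Placeholder path: any locally downloaded InstructPix2Pix checkpoint.
    state_dict = load_file("instruct-pix2pix-00-22000.safetensors")
    model_type = infer_diffusers_model_type(state_dict)  # -> "instruct-pix2pix"
    print(DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type])
    # {'pretrained_model_name_or_path': 'timbrooks/instruct-pix2pix'}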
@@ -633,6 +635,12 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
         model_type = "hunyuan-video"

+    elif (
+        CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
+    ):
+        model_type = "instruct-pix2pix"
+
     else:
         model_type = "v1"
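Why the shape check: a plain SD v1 checkpoint also contains model.diffusion_model.input_blocks.0.0.weight, so the key alone is ambiguous. The InstructPix2Pix UNet concatenates the encoded input image with the noise latents, so its first convolution takes 8 input channels instead of v1's 4, which is what shape[1] == 8 tests. A minimal standalone sketch (again with a placeholder checkpoint path):

    from safetensors.torch import load_file

    state_dict = load_file("instruct-pix2pix-00-22000.safetensors")  # placeholder path
    conv_in = state_dict["model.diffusion_model.input_blocks.0.0.weight"]
    print(conv_in.shape)  # torch.Size([320, 8, 3, 3]) for InstructPix2Pix,
                          # versus [320, 4, 3, 3] for a standard SD v1 UNet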

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

Lines changed: 13 additions & 2 deletions
@@ -22,16 +22,23 @@

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -79,6 +86,7 @@ class StableDiffusionInstructPix2PixPipeline(
     TextualInversionLoaderMixin,
     StableDiffusionLoraLoaderMixin,
     IPAdapterMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).

@@ -457,6 +465,9 @@ def __call__(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
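On the torch_xla side: the is_torch_xla_available() guard keeps the module importable when torch_xla is not installed, and xm.mark_step() inside the denoising loop tells XLA's lazy-tensor engine to compile and execute the graph accumulated for that step, rather than tracing the whole loop into one large graph. A short, hedged usage sketch for a TPU host with torch_xla installed (the input image URL is a placeholder):

    import torch_xla.core.xla_model as xm
    from diffusers import StableDiffusionInstructPix2PixPipeline
    from diffusers.utils import load_image

    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix")
    pipe = pipe.to(xm.xla_device())  # move the pipeline to the XLA device

    # Placeholder input image; the new per-step xm.mark_step() then runs
    # automatically inside __call__ during denoising.
    image = load_image("https://example.com/input.png")
    edited = pipe("turn the sky into a sunset", image=image).images[0]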
