Skip to content

Commit b24c6a8

Browse files
committed
update
1 parent e8e29ac commit b24c6a8

File tree

4 files changed

+12
-22
lines changed

4 files changed

+12
-22
lines changed

src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@
1414

1515
from typing import Callable, List, Optional, Union
1616

17-
import numpy as np
1817
import PIL.Image
1918
import torch
20-
from PIL import Image
2119
from transformers import (
2220
XLMRobertaTokenizer,
2321
)
2422

23+
from ...image_processor import VaeImageProcessor
2524
from ...models import UNet2DConditionModel, VQModel
2625
from ...schedulers import DDIMScheduler
2726
from ...utils import (
@@ -95,15 +94,6 @@ def get_new_h_w(h, w, scale_factor=8):
9594
return new_h * scale_factor, new_w * scale_factor
9695

9796

98-
def prepare_image(pil_image, w=512, h=512):
99-
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
100-
arr = np.array(pil_image.convert("RGB"))
101-
arr = arr.astype(np.float32) / 127.5 - 1
102-
arr = np.transpose(arr, [2, 0, 1])
103-
image = torch.from_numpy(arr).unsqueeze(0)
104-
return image
105-
106-
10797
class KandinskyImg2ImgPipeline(DiffusionPipeline):
10898
"""
10999
Pipeline for image-to-image generation using Kandinsky
@@ -144,6 +134,12 @@ def __init__(
144134
movq=movq,
145135
)
146136
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
137+
self.image_processor = VaeImageProcessor(
138+
vae_scale_factor=self.movq_scale_factor,
139+
vae_latent_channels=self.movq.config.latent_channels,
140+
resample="bicubic",
141+
reducing_gap=1,
142+
)
147143

148144
def get_timesteps(self, num_inference_steps, strength, device):
149145
# get the original timestep using init_timestep
@@ -417,7 +413,7 @@ def __call__(
417413
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
418414
)
419415

420-
image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
416+
image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
421417
image = image.to(dtype=prompt_embeds.dtype, device=device)
422418

423419
latents = self.movq.encode(image)["latents"]
@@ -498,13 +494,7 @@ def __call__(
498494
if output_type not in ["pt", "np", "pil"]:
499495
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
500496

501-
if output_type in ["np", "pil"]:
502-
image = image * 0.5 + 0.5
503-
image = image.clamp(0, 1)
504-
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
505-
506-
if output_type == "pil":
507-
image = self.numpy_to_pil(image)
497+
image = self.image_processor.postprocess(image, output_type)
508498

509499
if not return_dict:
510500
return (image,)

src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def __call__(
362362
if output_type not in ["pt", "np", "pil"]:
363363
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
364364

365-
image = self.image_processor.postprocess(image, output_type=output_type)
365+
image = self.image_processor.postprocess(image, output_type)
366366

367367
if not return_dict:
368368
return (image,)

src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def __call__(
370370
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
371371
else:
372372
image = latents
373-
image = self.image_processor.postprocess(image, output_type=output_type)
373+
image = self.image_processor.postprocess(image, output_type)
374374

375375
# Offload all models
376376
self.maybe_free_model_hooks()

src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ def __call__(
623623
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
624624
else:
625625
image = latents
626-
image = self.image_processor.postprocess(image, output_type=output_type)
626+
image = self.image_processor.postprocess(image, output_type)
627627

628628
self.maybe_free_model_hooks()
629629

0 commit comments

Comments
 (0)