| 14 | 14 | 
 | 
| 15 | 15 | from typing import Callable, List, Optional, Union | 
| 16 | 16 | 
 | 
| 17 |  | -import numpy as np | 
| 18 | 17 | import PIL.Image | 
| 19 | 18 | import torch | 
| 20 |  | -from PIL import Image | 
| 21 | 19 | from transformers import ( | 
| 22 | 20 |     XLMRobertaTokenizer, | 
| 23 | 21 | ) | 
| 24 | 22 | 
 | 
|  | 23 | +from ...image_processor import VaeImageProcessor | 
| 25 | 24 | from ...models import UNet2DConditionModel, VQModel | 
| 26 | 25 | from ...schedulers import DDIMScheduler | 
| 27 | 26 | from ...utils import ( | 
| @@ -95,15 +94,6 @@ def get_new_h_w(h, w, scale_factor=8): | 
| 95 | 94 |     return new_h * scale_factor, new_w * scale_factor | 
| 96 | 95 | 
 | 
| 97 | 96 | 
 | 
| 98 |  | -def prepare_image(pil_image, w=512, h=512): | 
| 99 |  | -    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) | 
| 100 |  | -    arr = np.array(pil_image.convert("RGB")) | 
| 101 |  | -    arr = arr.astype(np.float32) / 127.5 - 1 | 
| 102 |  | -    arr = np.transpose(arr, [2, 0, 1]) | 
| 103 |  | -    image = torch.from_numpy(arr).unsqueeze(0) | 
| 104 |  | -    return image | 
| 105 |  | - | 
| 106 |  | - | 
| 107 | 97 | class KandinskyImg2ImgPipeline(DiffusionPipeline): | 
| 108 | 98 |     """ | 
| 109 | 99 |     Pipeline for image-to-image generation using Kandinsky | 
| @@ -144,6 +134,12 @@ def __init__( | 
| 144 | 134 |             movq=movq, | 
| 145 | 135 |         ) | 
| 146 | 136 |         self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) | 
|  | 137 | +        self.image_processor = VaeImageProcessor( | 
|  | 138 | +            vae_scale_factor=self.movq_scale_factor, | 
|  | 139 | +            vae_latent_channels=self.movq.config.latent_channels, | 
|  | 140 | +            resample="bicubic", | 
|  | 141 | +            reducing_gap=1, | 
|  | 142 | +        ) | 
| 147 | 143 | 
 | 
| 148 | 144 |     def get_timesteps(self, num_inference_steps, strength, device): | 
| 149 | 145 |         # get the original timestep using init_timestep | 
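The body of `get_timesteps` is not shown in this hunk; in the usual diffusers img2img convention it trims the scheduler's timestep list according to `strength`, so only the tail of the schedule is re-run on top of the encoded init image. A minimal sketch of that arithmetic, assuming the standard convention (illustrative numbers, not code from this commit):

```python
# Hedged sketch of the usual img2img timestep selection: `strength` decides
# how much of the denoising schedule is applied to the encoded init image.
num_inference_steps = 100
strength = 0.3  # fraction of the schedule re-run on the init image

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)

# With 100 steps and strength=0.3, only the last 30 scheduler timesteps are
# used, so the result stays close to the input; strength=1.0 would re-run
# the full schedule.
print(init_timestep, t_start)  # 30 70
```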
| @@ -417,7 +413,7 @@ def __call__( | 
| 417 | 413 |                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support  PIL image and pytorch tensor" | 
| 418 | 414 |             ) | 
| 419 | 415 | 
 | 
| 420 |  | -        image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) | 
|  | 416 | +        image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0) | 
| 421 | 417 |         image = image.to(dtype=prompt_embeds.dtype, device=device) | 
| 422 | 418 | 
 | 
| 423 | 419 |         latents = self.movq.encode(image)["latents"] | 
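The removed `prepare_image` helper (bicubic resize with `reducing_gap=1`, RGB conversion, scaling to `[-1, 1]`, NCHW layout) is now expected to be covered by the `VaeImageProcessor` configured in `__init__` above. A standalone comparison sketch, assuming a diffusers version whose `VaeImageProcessor` accepts the `resample` and `reducing_gap` arguments shown in the diff:

```python
import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import VaeImageProcessor


def prepare_image_legacy(pil_image, w=512, h=512):
    # The helper removed by this change: bicubic resize, RGB, [-1, 1], NCHW.
    pil_image = pil_image.resize((w, h), resample=PIL.Image.BICUBIC, reducing_gap=1)
    arr = np.array(pil_image.convert("RGB")).astype(np.float32) / 127.5 - 1
    return torch.from_numpy(np.transpose(arr, [2, 0, 1])).unsqueeze(0)


# Configuration mirroring the one registered in the pipeline's __init__.
processor = VaeImageProcessor(vae_scale_factor=8, resample="bicubic", reducing_gap=1)

pil_image = PIL.Image.new("RGB", (640, 480), color=(128, 64, 32))
old = prepare_image_legacy(pil_image, w=512, h=512)
new = processor.preprocess(pil_image, height=512, width=512)

print(old.shape, new.shape)             # both (1, 3, 512, 512)
print((old - new).abs().max())          # expected to be ~0 up to resampling details
```

The old helper divided by 127.5 and subtracted 1, which is the same `[-1, 1]` normalization `VaeImageProcessor` applies by default, so the two paths should agree up to resampling behavior.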
| @@ -498,13 +494,7 @@ def __call__( | 
| 498 | 494 |         if output_type not in ["pt", "np", "pil"]: | 
| 499 | 495 |             raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") | 
| 500 | 496 | 
 | 
| 501 |  | -        if output_type in ["np", "pil"]: | 
| 502 |  | -            image = image * 0.5 + 0.5 | 
| 503 |  | -            image = image.clamp(0, 1) | 
| 504 |  | -            image = image.cpu().permute(0, 2, 3, 1).float().numpy() | 
| 505 |  | - | 
| 506 |  | -        if output_type == "pil": | 
| 507 |  | -            image = self.numpy_to_pil(image) | 
|  | 497 | +        image = self.image_processor.postprocess(image, output_type) | 
| 508 | 498 | 
 | 
| 509 | 499 |         if not return_dict: | 
| 510 | 500 |             return (image,) | 