
Commit 4f3bd13

bug fix kandinsky pipeline
1 parent 1001425 commit 4f3bd13

File tree

4 files changed: +33 additions, -104 deletions

4 files changed

+33
-104
lines changed

src/diffusers/image_processor.py

Lines changed: 2 additions & 1 deletion
@@ -116,6 +116,7 @@ def __init__(
         vae_scale_factor: int = 8,
         vae_latent_channels: int = 4,
         resample: str = "lanczos",
+        reducing_gap: int = None,
         do_normalize: bool = True,
         do_binarize: bool = False,
         do_convert_rgb: bool = False,
@@ -498,7 +499,7 @@ def resize(
             raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
         if isinstance(image, PIL.Image.Image):
             if resize_mode == "default":
-                image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+                image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample], reducing_gap=self.config.reducing_gap)
             elif resize_mode == "fill":
                 image = self._resize_and_fill(image, width, height)
             elif resize_mode == "crop":
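
For context, a minimal sketch (not part of the commit) of how the new reducing_gap option is meant to be used, assuming a diffusers build that includes this change; the input image and sizes are illustrative:

from PIL import Image
from diffusers.image_processor import VaeImageProcessor

# reducing_gap is stored on the processor config and forwarded to PIL.Image.resize
# in the "default" resize mode, mirroring what the removed prepare_image helpers did.
processor = VaeImageProcessor(resample="bicubic", reducing_gap=1)

image = Image.new("RGB", (1024, 768))          # illustrative input image
resized = processor.resize(image, height=512, width=512)
print(resized.size)                            # (512, 512)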

src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py

Lines changed: 10 additions & 32 deletions
@@ -27,7 +27,7 @@
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-
+from ...image_processor import VaeImageProcessor
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -105,27 +105,6 @@
 """
 
 
-# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
-def downscale_height_and_width(height, width, scale_factor=8):
-    new_height = height // scale_factor**2
-    if height % scale_factor**2 != 0:
-        new_height += 1
-    new_width = width // scale_factor**2
-    if width % scale_factor**2 != 0:
-        new_width += 1
-    return new_height * scale_factor, new_width * scale_factor
-
-
-# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
-def prepare_image(pil_image, w=512, h=512):
-    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
-    arr = np.array(pil_image.convert("RGB"))
-    arr = arr.astype(np.float32) / 127.5 - 1
-    arr = np.transpose(arr, [2, 0, 1])
-    image = torch.from_numpy(arr).unsqueeze(0)
-    return image
-
-
 class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
     """
     Pipeline for image-to-image generation using Kandinsky
@@ -157,7 +136,13 @@ def __init__(
             scheduler=scheduler,
             movq=movq,
         )
-        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+        movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor = movq_scale_factor,
+            vae_latent_channels = self.movq.config.latent_channels,
+            resample = "bicubic",
+            reducing_gap = 1,
+        )
 
     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
     def get_timesteps(self, num_inference_steps, strength, device):
@@ -316,15 +301,14 @@ def __call__(
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
             )
 
-        image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+        image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
         image = image.to(dtype=image_embeds.dtype, device=device)
 
         latents = self.movq.encode(image)["latents"]
         latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
         latents = self.prepare_latents(
             latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
         )
@@ -379,13 +363,7 @@ def __call__(
         if output_type not in ["pt", "np", "pil"]:
             raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
 
-        if output_type in ["np", "pil"]:
-            image = image * 0.5 + 0.5
-            image = image.clamp(0, 1)
-            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)
 
         if not return_dict:
             return (image,)
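
For reference, a rough sketch (not part of the commit) comparing the removed prepare_image helper with the new VaeImageProcessor-based preprocessing; both are expected to return a float tensor in [-1, 1] with shape (1, 3, H, W), assuming a diffusers build that includes this commit and matching resample/reducing_gap settings:

import numpy as np
import torch
from PIL import Image
from diffusers.image_processor import VaeImageProcessor

def prepare_image_old(pil_image, w=512, h=512):
    # behaviour of the helper deleted in this commit
    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
    arr = np.array(pil_image.convert("RGB")).astype(np.float32) / 127.5 - 1
    return torch.from_numpy(np.transpose(arr, [2, 0, 1])).unsqueeze(0)

processor = VaeImageProcessor(vae_scale_factor=8, resample="bicubic", reducing_gap=1)

pil = Image.new("RGB", (640, 480))                      # illustrative input
old = prepare_image_old(pil, 512, 512)                  # removed code path
new = processor.preprocess(pil, height=512, width=512)  # new code path
print(old.shape, new.shape)  # torch.Size([1, 3, 512, 512]) torch.Size([1, 3, 512, 512])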

src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py

Lines changed: 11 additions & 39 deletions
@@ -24,7 +24,7 @@
 from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-
+from ...image_processor import VaeImageProcessor
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -76,27 +76,6 @@
 """
 
 
-# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
-def downscale_height_and_width(height, width, scale_factor=8):
-    new_height = height // scale_factor**2
-    if height % scale_factor**2 != 0:
-        new_height += 1
-    new_width = width // scale_factor**2
-    if width % scale_factor**2 != 0:
-        new_width += 1
-    return new_height * scale_factor, new_width * scale_factor
-
-
-# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
-def prepare_image(pil_image, w=512, h=512):
-    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
-    arr = np.array(pil_image.convert("RGB"))
-    arr = arr.astype(np.float32) / 127.5 - 1
-    arr = np.transpose(arr, [2, 0, 1])
-    image = torch.from_numpy(arr).unsqueeze(0)
-    return image
-
-
 class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
     """
     Pipeline for image-to-image generation using Kandinsky
@@ -129,7 +108,13 @@ def __init__(
             scheduler=scheduler,
             movq=movq,
         )
-        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+        movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor = movq_scale_factor,
+            vae_latent_channels = self.movq.config.latent_channels,
+            resample = "bicubic",
+            reducing_gap = 1,
+        )
 
     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
     def get_timesteps(self, num_inference_steps, strength, device):
@@ -319,15 +304,14 @@ def __call__(
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
            )
 
-        image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+        image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
         image = image.to(dtype=image_embeds.dtype, device=device)
 
         latents = self.movq.encode(image)["latents"]
         latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
         latents = self.prepare_latents(
             latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
         )
@@ -383,24 +367,12 @@ def __call__(
         if XLA_AVAILABLE:
             xm.mark_step()
 
-        if output_type not in ["pt", "np", "pil", "latent"]:
-            raise ValueError(
-                f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
-            )
-
         if not output_type == "latent":
-            # post-processing
             image = self.movq.decode(latents, force_not_quantize=True)["sample"]
-            if output_type in ["np", "pil"]:
-                image = image * 0.5 + 0.5
-                image = image.clamp(0, 1)
-                image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-            if output_type == "pil":
-                image = self.numpy_to_pil(image)
         else:
            image = latents
-
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
 
         self.maybe_free_model_hooks()
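
As a rough sketch (not part of the commit) of what the new postprocess call does with the decoded movq sample: it denormalises from [-1, 1], clamps, and converts to the requested "pt", "np" or "pil" output, which is what the removed inline block did by hand. The tensor below is a random stand-in rather than a real movq output:

import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8, resample="bicubic")

decoded = torch.rand(1, 3, 512, 512) * 2 - 1   # stand-in for movq.decode(...)["sample"]

pil_images = processor.postprocess(decoded, output_type="pil")  # list of PIL.Image
np_images = processor.postprocess(decoded, output_type="np")    # float array in [0, 1], NHWC
print(len(pil_images), np_images.shape)        # 1 (1, 512, 512, 3)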

src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py

Lines changed: 10 additions & 32 deletions
@@ -18,7 +18,7 @@
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-
+from ...image_processor import VaeImageProcessor
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -53,24 +53,6 @@
 """
 
 
-def downscale_height_and_width(height, width, scale_factor=8):
-    new_height = height // scale_factor**2
-    if height % scale_factor**2 != 0:
-        new_height += 1
-    new_width = width // scale_factor**2
-    if width % scale_factor**2 != 0:
-        new_width += 1
-    return new_height * scale_factor, new_width * scale_factor
-
-
-def prepare_image(pil_image):
-    arr = np.array(pil_image.convert("RGB"))
-    arr = arr.astype(np.float32) / 127.5 - 1
-    arr = np.transpose(arr, [2, 0, 1])
-    image = torch.from_numpy(arr).unsqueeze(0)
-    return image
-
-
 class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
     model_cpu_offload_seq = "text_encoder->movq->unet->movq"
     _callback_tensor_inputs = [
@@ -94,6 +76,13 @@ def __init__(
         self.register_modules(
             tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
         )
+        movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor = movq_scale_factor,
+            vae_latent_channels = self.movq.config.latent_channels,
+            resample = "bicubic",
+            reducing_gap = 1,
+        )
 
     def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
@@ -566,7 +555,7 @@ def __call__(
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
             )
 
-        image = torch.cat([prepare_image(i) for i in image], dim=0)
+        image = torch.cat([self.image_processor.preprocess(i) for i in image], dim=0)
         image = image.to(dtype=prompt_embeds.dtype, device=device)
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -630,22 +619,11 @@ def __call__(
             xm.mark_step()
 
         # post-processing
-        if output_type not in ["pt", "np", "pil", "latent"]:
-            raise ValueError(
-                f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
-            )
         if not output_type == "latent":
             image = self.movq.decode(latents, force_not_quantize=True)["sample"]
-
-            if output_type in ["np", "pil"]:
-                image = image * 0.5 + 0.5
-                image = image.clamp(0, 1)
-                image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-            if output_type == "pil":
-                image = self.numpy_to_pil(image)
         else:
             image = latents
+        image = self.image_processor.postprocess(image, output_type=output_type)
 
         self.maybe_free_model_hooks()
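
One more note on the new __init__ logic shared by all three patched pipelines: the processor's vae_scale_factor is derived from the movq configuration, since each down block halves the spatial resolution. A small sketch with illustrative (hypothetical) config values, not read from a real checkpoint:

from diffusers.image_processor import VaeImageProcessor

# Illustrative movq config values (hypothetical, not a real checkpoint).
block_out_channels = [128, 256, 256, 512]
latent_channels = 4

# Same derivation the pipelines use: each down block halves the spatial size,
# so the overall factor is 2 ** (number of blocks - 1).
movq_scale_factor = 2 ** (len(block_out_channels) - 1)   # -> 8

image_processor = VaeImageProcessor(
    vae_scale_factor=movq_scale_factor,
    vae_latent_channels=latent_channels,
    resample="bicubic",
    reducing_gap=1,
)
print(movq_scale_factor)  # 8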
