Skip to content

Commit 0580379

Browse files
committed
fix image processing when input is tensor
1 parent 479d9d2 commit 0580379

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import urllib.parse as ul
1919
import warnings
2020
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21+
import torch.nn.functional as F
2122

2223
import torch
2324
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
@@ -579,14 +580,22 @@ def _clean_caption(self, caption):
579580

580581
def prepare_image(
581582
self,
582-
image,
583-
width,
584-
height,
585-
device,
586-
dtype,
583+
image: PipelineImageInput,
584+
width: int,
585+
height: int,
586+
device: torch.device,
587+
dtype: torch.dtype,
587588
):
588589
if isinstance(image, torch.Tensor):
589-
pass
590+
if image.ndim == 3:
591+
image = image.unsqueeze(0)
592+
# Resize if current dimensions do not match target dimensions.
593+
if image.shape[2] != height or image.shape[3] != width:
594+
image = F.interpolate(image, size=(height, width), mode="bilinear",
595+
align_corners=False)
596+
597+
image = self.image_processor.preprocess(image, height=height, width=width)
598+
590599
else:
591600
image = self.image_processor.preprocess(image, height=height, width=width)
592601

0 commit comments

Comments
 (0)