Skip to content

Commit 6161aa7

Browse files
committed
Move pil_to_tensor() and tensor_to_pil() utilities to the SpandrelImageToImage class.
1 parent 1ab20f4 commit 6161aa7

File tree

2 files changed

+47
-43
lines changed

2 files changed

+47
-43
lines changed
Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import numpy as np
21
import torch
3-
from PIL import Image
42

53
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
64
from invokeai.app.invocations.fields import (
@@ -17,44 +15,6 @@
1715
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
1816

1917

20-
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
21-
"""Convert PIL Image to torch.Tensor.
22-
23-
Args:
24-
image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
25-
26-
Returns:
27-
torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
28-
"""
29-
image_np = np.array(image)
30-
# (H, W, C) -> (C, H, W)
31-
image_np = np.transpose(image_np, (2, 0, 1))
32-
image_np = image_np / 255
33-
image_tensor = torch.from_numpy(image_np).float()
34-
# (C, H, W) -> (N, C, H, W)
35-
image_tensor = image_tensor.unsqueeze(0)
36-
return image_tensor
37-
38-
39-
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
40-
"""Convert torch.Tensor to PIL Image.
41-
42-
Args:
43-
tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
44-
45-
Returns:
46-
Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
47-
"""
48-
# (N, C, H, W) -> (C, H, W)
49-
tensor = tensor.squeeze(0)
50-
# (C, H, W) -> (H, W, C)
51-
tensor = tensor.permute(1, 2, 0)
52-
tensor = tensor.clamp(0, 1)
53-
tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
54-
image = Image.fromarray(tensor)
55-
return image
56-
57-
5818
@invocation("upscale_spandrel", title="Upscale (spandrel)", tags=["upscale"], category="upscale", version="1.0.0")
5919
class UpscaleSpandrelInvocation(BaseInvocation, WithMetadata, WithBoard):
6020
"""Upscales an image using any upscaler supported by spandrel (https://github.com/chaiNNer-org/spandrel)."""
@@ -75,13 +35,13 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
7535
assert isinstance(spandrel_model, SpandrelImageToImageModel)
7636

7737
# Prepare input image for inference.
78-
image_tensor = pil_to_tensor(image)
38+
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
7939
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
8040

8141
# Run inference.
8242
image_tensor = spandrel_model.run(image_tensor)
8343

8444
# Convert the output tensor to a PIL image.
85-
pil_image = tensor_to_pil(image_tensor)
45+
pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor)
8646
image_dto = context.images.save(image=pil_image)
8747
return ImageOutput.build(image_dto)

invokeai/backend/spandrel_image_to_image_model.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from pathlib import Path
22
from typing import Any, Optional
33

4+
import numpy as np
45
import torch
6+
from PIL import Image
57
from spandrel import ImageModelDescriptor, ModelLoader
68

79
from invokeai.backend.raw_model import RawModel
@@ -16,8 +18,50 @@ class SpandrelImageToImageModel(RawModel):
1618
def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
1719
self._spandrel_model = spandrel_model
1820

21+
@staticmethod
22+
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
23+
"""Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().
24+
25+
Args:
26+
image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
27+
28+
Returns:
29+
torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
30+
"""
31+
image_np = np.array(image)
32+
# (H, W, C) -> (C, H, W)
33+
image_np = np.transpose(image_np, (2, 0, 1))
34+
image_np = image_np / 255
35+
image_tensor = torch.from_numpy(image_np).float()
36+
# (C, H, W) -> (N, C, H, W)
37+
image_tensor = image_tensor.unsqueeze(0)
38+
return image_tensor
39+
40+
@staticmethod
41+
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
42+
"""Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.
43+
44+
Args:
45+
tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
46+
47+
Returns:
48+
Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
49+
"""
50+
# (N, C, H, W) -> (C, H, W)
51+
tensor = tensor.squeeze(0)
52+
# (C, H, W) -> (H, W, C)
53+
tensor = tensor.permute(1, 2, 0)
54+
tensor = tensor.clamp(0, 1)
55+
tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
56+
image = Image.fromarray(tensor)
57+
return image
58+
1959
def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
20-
"""Run the image-to-image model."""
60+
"""Run the image-to-image model.
61+
62+
Args:
63+
image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
64+
"""
2165
return self._spandrel_model(image_tensor)
2266

2367
@classmethod

0 commit comments

Comments
 (0)