1
- import numpy as np
2
1
import torch
3
- from PIL import Image
4
2
5
3
from invokeai .app .invocations .baseinvocation import BaseInvocation , invocation
6
4
from invokeai .app .invocations .fields import (
17
15
from invokeai .backend .spandrel_image_to_image_model import SpandrelImageToImageModel
18
16
19
17
20
- def pil_to_tensor (image : Image .Image ) -> torch .Tensor :
21
- """Convert PIL Image to torch.Tensor.
22
-
23
- Args:
24
- image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
25
-
26
- Returns:
27
- torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
28
- """
29
- image_np = np .array (image )
30
- # (H, W, C) -> (C, H, W)
31
- image_np = np .transpose (image_np , (2 , 0 , 1 ))
32
- image_np = image_np / 255
33
- image_tensor = torch .from_numpy (image_np ).float ()
34
- # (C, H, W) -> (N, C, H, W)
35
- image_tensor = image_tensor .unsqueeze (0 )
36
- return image_tensor
37
-
38
-
39
- def tensor_to_pil (tensor : torch .Tensor ) -> Image .Image :
40
- """Convert torch.Tensor to PIL Image.
41
-
42
- Args:
43
- tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
44
-
45
- Returns:
46
- Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
47
- """
48
- # (N, C, H, W) -> (C, H, W)
49
- tensor = tensor .squeeze (0 )
50
- # (C, H, W) -> (H, W, C)
51
- tensor = tensor .permute (1 , 2 , 0 )
52
- tensor = tensor .clamp (0 , 1 )
53
- tensor = (tensor * 255 ).cpu ().detach ().numpy ().astype (np .uint8 )
54
- image = Image .fromarray (tensor )
55
- return image
56
-
57
-
58
18
@invocation ("upscale_spandrel" , title = "Upscale (spandrel)" , tags = ["upscale" ], category = "upscale" , version = "1.0.0" )
59
19
class UpscaleSpandrelInvocation (BaseInvocation , WithMetadata , WithBoard ):
60
20
"""Upscales an image using any upscaler supported by spandrel (https://github.com/chaiNNer-org/spandrel)."""
@@ -75,13 +35,13 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
75
35
assert isinstance (spandrel_model , SpandrelImageToImageModel )
76
36
77
37
# Prepare input image for inference.
78
- image_tensor = pil_to_tensor (image )
38
+ image_tensor = SpandrelImageToImageModel . pil_to_tensor (image )
79
39
image_tensor = image_tensor .to (device = spandrel_model .device , dtype = spandrel_model .dtype )
80
40
81
41
# Run inference.
82
42
image_tensor = spandrel_model .run (image_tensor )
83
43
84
44
# Convert the output tensor to a PIL image.
85
- pil_image = tensor_to_pil (image_tensor )
45
+ pil_image = SpandrelImageToImageModel . tensor_to_pil (image_tensor )
86
46
image_dto = context .images .save (image = pil_image )
87
47
return ImageOutput .build (image_dto )
0 commit comments