1
+ import numpy as np
1
2
import torch
3
+ from PIL import Image
4
+ from tqdm import tqdm
2
5
3
6
from invokeai .app .invocations .baseinvocation import BaseInvocation , invocation
4
7
from invokeai .app .invocations .fields import (
13
16
from invokeai .app .invocations .primitives import ImageOutput
14
17
from invokeai .app .services .shared .invocation_context import InvocationContext
15
18
from invokeai .backend .spandrel_image_to_image_model import SpandrelImageToImageModel
19
+ from invokeai .backend .tiles .tiles import calc_tiles_min_overlap , merge_tiles_with_linear_blending
20
+ from invokeai .backend .tiles .utils import TBLR , Tile
16
21
17
22
18
- @invocation ("spandrel_image_to_image" , title = "Image-to-Image" , tags = ["upscale" ], category = "upscale" , version = "1.0 .0" )
23
+ @invocation ("spandrel_image_to_image" , title = "Image-to-Image" , tags = ["upscale" ], category = "upscale" , version = "1.1 .0" )
19
24
class SpandrelImageToImageInvocation (BaseInvocation , WithMetadata , WithBoard ):
20
25
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
21
26
@@ -25,25 +30,85 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
25
30
description = FieldDescriptions .spandrel_image_to_image_model ,
26
31
ui_type = UIType .SpandrelImageToImageModel ,
27
32
)
33
+ tile_size : int = InputField (
34
+ default = 512 , description = "The tile size for tiled image-to-image. Set to 0 to disable tiling."
35
+ )
36
+
37
def _scale_tile(self, tile: Tile, scale: int) -> Tile:
    """Return a new Tile whose coords and overlap are those of `tile` multiplied by `scale`.

    Args:
        tile: The tile to scale. It is read, not modified.
        scale: Integer factor applied to every edge of both the coords and the overlap.

    Returns:
        A freshly constructed Tile with the scaled coords and overlap.
    """

    def _scaled(box: TBLR) -> TBLR:
        # Multiply each of the four edges by the same factor.
        return TBLR(
            top=box.top * scale,
            bottom=box.bottom * scale,
            left=box.left * scale,
            right=box.right * scale,
        )

    scaled_coords = _scaled(tile.coords)
    scaled_overlap = _scaled(tile.overlap)
    return Tile(coords=scaled_coords, overlap=scaled_overlap)
28
52
29
53
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
    """Run the selected spandrel image-to-image model over the input image and save the result.

    The image is processed tile by tile when `tile_size` is greater than 0; otherwise a single
    tile covering the whole image is used. Each output tile is converted to a host numpy array as
    soon as it is produced, so that only one model output tensor is referenced at a time, and the
    tiles are then merged with linear blending into the final image.

    Args:
        context: Invocation context used to fetch the input image, load the model and save the output.

    Returns:
        An ImageOutput built from the saved output image.
    """
    # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
    # revisit this.
    image = context.images.get_pil(self.image.image_name, mode="RGB")

    # Compute the image tiles.
    if self.tile_size > 0:
        min_overlap = 20
        tiles = calc_tiles_min_overlap(
            image_height=image.height,
            image_width=image.width,
            tile_height=self.tile_size,
            tile_width=self.tile_size,
            min_overlap=min_overlap,
        )
    else:
        # No tiling. Generate a single tile that covers the entire image.
        min_overlap = 0
        tiles = [
            Tile(
                coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
                overlap=TBLR(top=0, bottom=0, left=0, right=0),
            )
        ]

    # Prepare input image for inference.
    image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)

    # Load the model.
    spandrel_model_info = context.models.load(self.image_to_image_model)

    # Run the model on each tile. Output tiles are collected as numpy arrays rather than as model output
    # tensors, so that finished tiles do not keep accumulating on the model's device.
    np_output_tiles: list[np.ndarray] = []
    scale: int = 1
    with spandrel_model_info as spandrel_model:
        assert isinstance(spandrel_model, SpandrelImageToImageModel)

        # Scale the tiles for re-assembling the final image.
        scale = spandrel_model.scale
        scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]

        image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)

        for tile in tqdm(tiles, desc="Upscaling Tiles"):
            output_tile = spandrel_model.run(
                image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
            )
            # Convert right away (same conversion that the merge step needs) so the output tensor can be
            # released before the next tile is processed.
            np_output_tiles.append(np.array(SpandrelImageToImageModel.tensor_to_pil(output_tile)))

    # Merge tiles into output image.
    _, channels, height, width = image_tensor.shape
    np_out_image = np.zeros((height * scale, width * scale, channels), dtype=np.uint8)
    merge_tiles_with_linear_blending(
        dst_image=np_out_image, tiles=scaled_tiles, tile_images=np_output_tiles, blend_amount=min_overlap // 2
    )

    # Convert the merged numpy array to a PIL image.
    pil_image = Image.fromarray(np_out_image)
    image_dto = context.images.save(image=pil_image)
    return ImageOutput.build(image_dto)
0 commit comments