1
+ import numpy as np
1
2
import torch
3
+ from PIL import Image
2
4
from tqdm import tqdm
3
5
4
6
from invokeai .app .invocations .baseinvocation import BaseInvocation , invocation
@@ -48,29 +50,6 @@ def _scale_tile(self, tile: Tile, scale: int) -> Tile:
48
50
),
49
51
)
50
52
51
- def _merge_tiles (self , tiles : list [Tile ], tile_tensors : list [torch .Tensor ], out_tensor : torch .Tensor ):
52
- """A simple tile merging algorithm. tile_tensors are merged into out_tensor. When adjacent tiles overlap, we
53
- split the overlap in half. No 'blending' is applied.
54
- """
55
- # Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
56
- # iterate over tiles left-to-right, top-to-bottom.
57
- tiles_and_tensors = list (zip (tiles , tile_tensors , strict = True ))
58
- tiles_and_tensors = sorted (tiles_and_tensors , key = lambda x : x [0 ].coords .left )
59
- tiles_and_tensors = sorted (tiles_and_tensors , key = lambda x : x [0 ].coords .top )
60
-
61
- for tile , tile_tensor in tiles_and_tensors :
62
- # We only keep half of the overlap on the top and left side of the tile. We do this in case there are edge
63
- # artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers it seems
64
- # unnecessary, but we may find a need in the future.
65
- top_overlap = tile .overlap .top // 2
66
- left_overlap = tile .overlap .left // 2
67
- out_tensor [
68
- :,
69
- :,
70
- tile .coords .top + top_overlap : tile .coords .bottom ,
71
- tile .coords .left + left_overlap : tile .coords .right ,
72
- ] = tile_tensor [:, :, top_overlap :, left_overlap :]
73
-
74
53
@torch .inference_mode ()
75
54
def invoke (self , context : InvocationContext ) -> ImageOutput :
76
55
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
@@ -97,43 +76,64 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
97
76
)
98
77
]
99
78
79
+ # Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate
80
+ # over tiles left-to-right, top-to-bottom.
81
+ tiles = sorted (tiles , key = lambda x : x .coords .left )
82
+ tiles = sorted (tiles , key = lambda x : x .coords .top )
83
+
100
84
# Prepare input image for inference.
101
85
image_tensor = SpandrelImageToImageModel .pil_to_tensor (image )
102
86
103
87
# Load the model.
104
88
spandrel_model_info = context .models .load (self .image_to_image_model )
105
89
106
90
# Run the model on each tile.
107
- output_tiles : list [torch .Tensor ] = []
108
- scale : int = 1
109
91
with spandrel_model_info as spandrel_model :
110
92
assert isinstance (spandrel_model , SpandrelImageToImageModel )
111
93
112
94
# Scale the tiles for re-assembling the final image.
113
95
scale = spandrel_model .scale
114
96
scaled_tiles = [self ._scale_tile (tile , scale = scale ) for tile in tiles ]
115
97
116
- image_tensor = image_tensor .to (device = spandrel_model .device , dtype = spandrel_model .dtype )
117
-
118
- for tile in tqdm (tiles , desc = "Upscaling Tiles" ):
119
- output_tile = spandrel_model .run (
120
- image_tensor [:, :, tile .coords .top : tile .coords .bottom , tile .coords .left : tile .coords .right ]
121
- )
122
- output_tiles .append (output_tile )
98
+ # Prepare the output tensor.
99
+ _ , channels , height , width = image_tensor .shape
100
+ output_tensor = torch .zeros (
101
+ (height * scale , width * scale , channels ), dtype = torch .uint8 , device = torch .device ("cpu" )
102
+ )
123
103
124
- # TODO(ryand): There are opportunities to reduce peak VRAM utilization here if it becomes an issue:
125
- # - Keep the input tensor on the CPU.
126
- # - Move each tile to the GPU as it is processed.
127
- # - Move output tensors back to the CPU as they are produced, and merge them into the output tensor.
104
+ image_tensor = image_tensor .to (device = spandrel_model .device , dtype = spandrel_model .dtype )
128
105
129
- # Merge the tiles to an output tensor.
130
- batch_size , channels , height , width = image_tensor .shape
131
- output_tensor = torch .zeros (
132
- (batch_size , channels , height * scale , width * scale ), dtype = image_tensor .dtype , device = image_tensor .device
133
- )
134
- self ._merge_tiles (scaled_tiles , output_tiles , output_tensor )
106
+ for tile , scaled_tile in tqdm (list (zip (tiles , scaled_tiles , strict = True )), desc = "Upscaling Tiles" ):
107
+ # Extract the current tile from the input tensor.
108
+ input_tile = image_tensor [
109
+ :, :, tile .coords .top : tile .coords .bottom , tile .coords .left : tile .coords .right
110
+ ].to (device = spandrel_model .device , dtype = spandrel_model .dtype )
111
+
112
+ # Run the model on the tile.
113
+ output_tile = spandrel_model .run (input_tile )
114
+
115
+ # Convert the output tile into the output tensor's format.
116
+ # (N, C, H, W) -> (C, H, W)
117
+ output_tile = output_tile .squeeze (0 )
118
+ # (C, H, W) -> (H, W, C)
119
+ output_tile = output_tile .permute (1 , 2 , 0 )
120
+ output_tile = output_tile .clamp (0 , 1 )
121
+ output_tile = (output_tile * 255 ).to (dtype = torch .uint8 , device = torch .device ("cpu" ))
122
+
123
+ # Merge the output tile into the output tensor.
124
+ # We only keep half of the overlap on the top and left side of the tile. We do this in case there are
125
+ # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
126
+ # it seems unnecessary, but we may find a need in the future.
127
+ top_overlap = scaled_tile .overlap .top // 2
128
+ left_overlap = scaled_tile .overlap .left // 2
129
+ output_tensor [
130
+ scaled_tile .coords .top + top_overlap : scaled_tile .coords .bottom ,
131
+ scaled_tile .coords .left + left_overlap : scaled_tile .coords .right ,
132
+ :,
133
+ ] = output_tile [top_overlap :, left_overlap :, :]
135
134
136
135
# Convert the output tensor to a PIL image.
137
- pil_image = SpandrelImageToImageModel .tensor_to_pil (output_tensor )
136
+ np_image = output_tensor .detach ().numpy ().astype (np .uint8 )
137
+ pil_image = Image .fromarray (np_image )
138
138
image_dto = context .images .save (image = pil_image )
139
139
return ImageOutput .build (image_dto )
0 commit comments