Skip to content

Commit d868d5d

Browse files
committed
Make SpandrelImageToImage tiling much faster.
1 parent ab77572 commit d868d5d

File tree

1 file changed

+35
-10
lines changed

1 file changed

+35
-10
lines changed

invokeai/app/invocations/spandrel_image_to_image.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import numpy as np
21
import torch
3-
from PIL import Image
42
from tqdm import tqdm
53

64
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@@ -16,7 +14,7 @@
1614
from invokeai.app.invocations.primitives import ImageOutput
1715
from invokeai.app.services.shared.invocation_context import InvocationContext
1816
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
19-
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending
17+
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
2018
from invokeai.backend.tiles.utils import TBLR, Tile
2119

2220

@@ -50,6 +48,29 @@ def _scale_tile(self, tile: Tile, scale: int) -> Tile:
5048
),
5149
)
5250

51+
def _merge_tiles(self, tiles: list[Tile], tile_tensors: list[torch.Tensor], out_tensor: torch.Tensor):
52+
"""A simple tile merging algorithm. tile_tensors are merged into out_tensor. When adjacent tiles overlap, we
53+
split the overlap in half. No 'blending' is applied.
54+
"""
55+
# Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
56+
# iterate over tiles left-to-right, top-to-bottom.
57+
tiles_and_tensors = list(zip(tiles, tile_tensors, strict=True))
58+
tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.left)
59+
tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.top)
60+
61+
for tile, tile_tensor in tiles_and_tensors:
62+
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are edge
63+
# artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers it seems
64+
# unnecessary, but we may find a need in the future.
65+
top_overlap = tile.overlap.top // 2
66+
left_overlap = tile.overlap.left // 2
67+
out_tensor[
68+
:,
69+
:,
70+
tile.coords.top + top_overlap : tile.coords.bottom,
71+
tile.coords.left + left_overlap : tile.coords.right,
72+
] = tile_tensor[:, :, top_overlap:, left_overlap:]
73+
5374
@torch.inference_mode()
5475
def invoke(self, context: InvocationContext) -> ImageOutput:
5576
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
@@ -100,15 +121,19 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
100121
)
101122
output_tiles.append(output_tile)
102123

103-
# Merge tiles into output image.
104-
np_output_tiles = [np.array(SpandrelImageToImageModel.tensor_to_pil(tile)) for tile in output_tiles]
105-
_, channels, height, width = image_tensor.shape
106-
np_out_image = np.zeros((height * scale, width * scale, channels), dtype=np.uint8)
107-
merge_tiles_with_linear_blending(
108-
dst_image=np_out_image, tiles=scaled_tiles, tile_images=np_output_tiles, blend_amount=min_overlap // 2
124+
# TODO(ryand): There are opportunities to reduce peak VRAM utilization here if it becomes an issue:
125+
# - Keep the input tensor on the CPU.
126+
# - Move each tile to the GPU as it is processed.
127+
# - Move output tensors back to the CPU as they are produced, and merge them into the output tensor.
128+
129+
# Merge the tiles to an output tensor.
130+
batch_size, channels, height, width = image_tensor.shape
131+
output_tensor = torch.zeros(
132+
(batch_size, channels, height * scale, width * scale), dtype=image_tensor.dtype, device=image_tensor.device
109133
)
134+
self._merge_tiles(scaled_tiles, output_tiles, output_tensor)
110135

111136
# Convert the output tensor to a PIL image.
112-
pil_image = Image.fromarray(np_out_image)
137+
pil_image = SpandrelImageToImageModel.tensor_to_pil(output_tensor)
113138
image_dto = context.images.save(image=pil_image)
114139
return ImageOutput.build(image_dto)

0 commit comments

Comments
 (0)