Commit 95e9f53

Add tiling to SpandrelImageToImageInvocation (#6594)
## Summary

Add tiling to the `SpandrelImageToImageInvocation` node so that it can process large images.

Tiling enables this node to run on effectively any input image dimension. Of course, the computation time increases quadratically with the image dimension. Some profiling results on an RTX4090:

- Input 1024x1024, 4x upscale, 4x UltraSharp ESRGAN: `13 secs`, `<4 GB VRAM`
- Input 4096x4096, 4x upscale, 4x UltraSharp ESRGAN: `46 secs`, `<4 GB VRAM`
- Input 4096x4096, 2x upscale, SwinIR: `165 secs`, `<5 GB VRAM`

A lot of the time is spent PNG-encoding the final image:

- PNG encoding of a 16384x16384 image takes `83 secs @ pil_compress_level=7`, `24 secs @ pil_compress_level=1` (a rough benchmarking sketch follows below).

Callout: If we want to start building workflows that pass large images between nodes, we are going to have to find a way to avoid the PNG encode/decode round trip that we are currently doing. As is, we will incur a huge penalty for every node that receives/produces a large image.

## QA Instructions

- [x] Tested with tiling up to 4096x4096 -> 16384x16384.
- [x] Test on images with an alpha channel (the alpha channel is dropped).
- [x] Test on images with odd dimensions.
- [x] Test no tiling (`tile_size=0`).

## Merge Plan

- [x] Merge #6556 first, and change the target branch to `main`.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
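As a rough illustration of the PNG-encoding cost called out above, here is a minimal benchmarking sketch (not part of this commit). It assumes Pillow and NumPy are available and uses a synthetic image and an arbitrary output path; the timings quoted in the description were measured on the real 16384x16384 upscaler output.

```python
import time

import numpy as np
from PIL import Image

# Synthetic stand-in for a large upscaler output. 4096x4096 keeps the demo quick;
# the numbers quoted above were measured on a 16384x16384 image.
pil_image = Image.fromarray(np.random.randint(0, 256, size=(4096, 4096, 3), dtype=np.uint8))

for compress_level in (1, 7):
    start = time.time()
    # Pillow's PNG writer takes a zlib compression level from 0 (none) to 9 (max).
    pil_image.save(f"/tmp/out_cl{compress_level}.png", compress_level=compress_level)
    print(f"compress_level={compress_level}: {time.time() - start:.1f}s")
```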
2 parents 7ad32dc + 6b0ca88 commit 95e9f53

2 files changed: 107 additions & 7 deletions

Lines changed: 102 additions & 7 deletions

@@ -1,4 +1,7 @@
+import numpy as np
 import torch
+from PIL import Image
+from tqdm import tqdm
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import (
@@ -11,11 +14,14 @@
 )
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.session_processor.session_processor_common import CanceledException
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
+from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
+from invokeai.backend.tiles.utils import TBLR, Tile
 
 
-@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.0.0")
+@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
 class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
 
@@ -25,25 +31,114 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         description=FieldDescriptions.spandrel_image_to_image_model,
         ui_type=UIType.SpandrelImageToImageModel,
     )
+    tile_size: int = InputField(
+        default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
+    )
+
+    def _scale_tile(self, tile: Tile, scale: int) -> Tile:
+        return Tile(
+            coords=TBLR(
+                top=tile.coords.top * scale,
+                bottom=tile.coords.bottom * scale,
+                left=tile.coords.left * scale,
+                right=tile.coords.right * scale,
+            ),
+            overlap=TBLR(
+                top=tile.overlap.top * scale,
+                bottom=tile.overlap.bottom * scale,
+                left=tile.overlap.left * scale,
+                right=tile.overlap.right * scale,
+            ),
+        )
 
     @torch.inference_mode()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.images.get_pil(self.image.image_name)
+        # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
+        # revisit this.
+        image = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        # Compute the image tiles.
+        if self.tile_size > 0:
+            min_overlap = 20
+            tiles = calc_tiles_min_overlap(
+                image_height=image.height,
+                image_width=image.width,
+                tile_height=self.tile_size,
+                tile_width=self.tile_size,
+                min_overlap=min_overlap,
+            )
+        else:
+            # No tiling. Generate a single tile that covers the entire image.
+            min_overlap = 0
+            tiles = [
+                Tile(
+                    coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
+                    overlap=TBLR(top=0, bottom=0, left=0, right=0),
+                )
+            ]
+
+        # Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate
+        # over tiles left-to-right, top-to-bottom.
+        tiles = sorted(tiles, key=lambda x: x.coords.left)
+        tiles = sorted(tiles, key=lambda x: x.coords.top)
+
+        # Prepare input image for inference.
+        image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
 
         # Load the model.
         spandrel_model_info = context.models.load(self.image_to_image_model)
 
+        # Run the model on each tile.
         with spandrel_model_info as spandrel_model:
             assert isinstance(spandrel_model, SpandrelImageToImageModel)
 
-            # Prepare input image for inference.
-            image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
+            # Scale the tiles for re-assembling the final image.
+            scale = spandrel_model.scale
+            scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
+
+            # Prepare the output tensor.
+            _, channels, height, width = image_tensor.shape
+            output_tensor = torch.zeros(
+                (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
+            )
+
             image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
 
-            # Run inference.
-            image_tensor = spandrel_model.run(image_tensor)
+            for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
+                # Exit early if the invocation has been canceled.
+                if context.util.is_canceled():
+                    raise CanceledException
+
+                # Extract the current tile from the input tensor.
+                input_tile = image_tensor[
+                    :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
+                ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
+
+                # Run the model on the tile.
+                output_tile = spandrel_model.run(input_tile)
+
+                # Convert the output tile into the output tensor's format.
+                # (N, C, H, W) -> (C, H, W)
+                output_tile = output_tile.squeeze(0)
+                # (C, H, W) -> (H, W, C)
+                output_tile = output_tile.permute(1, 2, 0)
+                output_tile = output_tile.clamp(0, 1)
+                output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
+
+                # Merge the output tile into the output tensor.
+                # We only keep half of the overlap on the top and left side of the tile. We do this in case there are
+                # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
+                # it seems unnecessary, but we may find a need in the future.
+                top_overlap = scaled_tile.overlap.top // 2
+                left_overlap = scaled_tile.overlap.left // 2
+                output_tensor[
+                    scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
+                    scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
+                    :,
+                ] = output_tile[top_overlap:, left_overlap:, :]
 
         # Convert the output tensor to a PIL image.
-        pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor)
+        np_image = output_tensor.detach().numpy().astype(np.uint8)
+        pil_image = Image.fromarray(np_image)
         image_dto = context.images.save(image=pil_image)
         return ImageOutput.build(image_dto)
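For intuition about where the merge loop above places seams, here is a simplified, standalone 1-D sketch with hypothetical numbers. It is not the actual `calc_tiles_min_overlap` implementation; it only mimics the "keep the second half of the leading overlap" rule used by the node.

```python
import math

# Hypothetical 1-D example: a 1000 px edge, 512 px tiles, at least 20 px overlap.
edge, tile, min_overlap = 1000, 512, 20

# Number of tiles needed so consecutive tiles overlap by at least min_overlap.
num_tiles = math.ceil((edge - min_overlap) / (tile - min_overlap))

# Evenly spaced tile starts; the last tile ends exactly at the edge.
starts = [round(i * (edge - tile) / (num_tiles - 1)) if num_tiles > 1 else 0 for i in range(num_tiles)]
tiles = [(start, start + tile) for start in starts]
print(tiles)  # [(0, 512), (488, 1000)] -> 24 px of overlap

# Merge rule from the invocation: tile i+1 discards the first half of its leading
# overlap, so the seam between tiles i and i+1 lands at start_{i+1} + overlap // 2
# (in scaled coordinates in the real code).
for (_, a_end), (b_start, _) in zip(tiles, tiles[1:]):
    overlap = a_end - b_start
    print(f"overlap={overlap}px, seam at x={b_start + overlap // 2}")
```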

invokeai/backend/spandrel_image_to_image_model.py

Lines changed: 5 additions & 0 deletions

@@ -126,6 +126,11 @@ def dtype(self) -> torch.dtype:
         """The dtype of the underlying model."""
         return self._spandrel_model.dtype
 
+    @property
+    def scale(self) -> int:
+        """The scale of the model (e.g. 1x, 2x, 4x, etc.)."""
+        return self._spandrel_model.scale
+
     def calc_size(self) -> int:
         """Get size of the model in memory in bytes."""
         # HACK(ryand): Fix this issue with circular imports.
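For context, the new `scale` property simply forwards the scale factor reported by the wrapped spandrel model descriptor. A rough usage sketch follows; it assumes the spandrel `ModelLoader` API and a hypothetical local checkpoint file name.

```python
import torch
from spandrel import ModelLoader

# Hypothetical checkpoint; any spandrel-supported image-to-image model should work.
model = ModelLoader().load_from_file("4x-UltraSharp.pth")

# Scale factor reported by the descriptor (e.g. 4 for a 4x ESRGAN).
scale = model.scale

# The invocation preallocates its (H, W, C) output buffer from this factor,
# e.g. a 1024x1024 RGB input with a 4x model -> a 4096x4096x3 uint8 buffer
# (4096x4096 -> 16384x16384 in the QA runs above).
output_tensor = torch.zeros((1024 * scale, 1024 * scale, 3), dtype=torch.uint8)
print(scale, tuple(output_tensor.shape))
```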
