Commit ab77572

Add tiling support to the SpandrelImageToImage node.
1 parent 650902d commit ab77572

2 files changed: +77 -7 lines changed

invokeai/app/invocations/spandrel_image_to_image.py

Lines changed: 72 additions & 7 deletions
@@ -1,4 +1,7 @@
+import numpy as np
 import torch
+from PIL import Image
+from tqdm import tqdm
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import (
@@ -13,9 +16,11 @@
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
+from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending
+from invokeai.backend.tiles.utils import TBLR, Tile
 
 
-@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.0.0")
+@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
 class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
 
@@ -25,25 +30,85 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         description=FieldDescriptions.spandrel_image_to_image_model,
         ui_type=UIType.SpandrelImageToImageModel,
     )
+    tile_size: int = InputField(
+        default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
+    )
+
+    def _scale_tile(self, tile: Tile, scale: int) -> Tile:
+        return Tile(
+            coords=TBLR(
+                top=tile.coords.top * scale,
+                bottom=tile.coords.bottom * scale,
+                left=tile.coords.left * scale,
+                right=tile.coords.right * scale,
+            ),
+            overlap=TBLR(
+                top=tile.overlap.top * scale,
+                bottom=tile.overlap.bottom * scale,
+                left=tile.overlap.left * scale,
+                right=tile.overlap.right * scale,
+            ),
+        )
 
     @torch.inference_mode()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.images.get_pil(self.image.image_name)
+        # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
+        # revisit this.
+        image = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        # Compute the image tiles.
+        if self.tile_size > 0:
+            min_overlap = 20
+            tiles = calc_tiles_min_overlap(
+                image_height=image.height,
+                image_width=image.width,
+                tile_height=self.tile_size,
+                tile_width=self.tile_size,
+                min_overlap=min_overlap,
+            )
+        else:
+            # No tiling. Generate a single tile that covers the entire image.
+            min_overlap = 0
+            tiles = [
+                Tile(
+                    coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
+                    overlap=TBLR(top=0, bottom=0, left=0, right=0),
+                )
+            ]
+
+        # Prepare input image for inference.
+        image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
 
         # Load the model.
         spandrel_model_info = context.models.load(self.image_to_image_model)
 
+        # Run the model on each tile.
+        output_tiles: list[torch.Tensor] = []
+        scale: int = 1
         with spandrel_model_info as spandrel_model:
             assert isinstance(spandrel_model, SpandrelImageToImageModel)
 
-            # Prepare input image for inference.
-            image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
+            # Scale the tiles for re-assembling the final image.
+            scale = spandrel_model.scale
+            scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
+
             image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
 
-            # Run inference.
-            image_tensor = spandrel_model.run(image_tensor)
+            for tile in tqdm(tiles, desc="Upscaling Tiles"):
+                output_tile = spandrel_model.run(
+                    image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
+                )
+                output_tiles.append(output_tile)
+
+            # Merge tiles into output image.
+            np_output_tiles = [np.array(SpandrelImageToImageModel.tensor_to_pil(tile)) for tile in output_tiles]
+            _, channels, height, width = image_tensor.shape
+            np_out_image = np.zeros((height * scale, width * scale, channels), dtype=np.uint8)
+            merge_tiles_with_linear_blending(
+                dst_image=np_out_image, tiles=scaled_tiles, tile_images=np_output_tiles, blend_amount=min_overlap // 2
+            )
 
         # Convert the output tensor to a PIL image.
-        pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor)
+        pil_image = Image.fromarray(np_out_image)
         image_dto = context.images.save(image=pil_image)
         return ImageOutput.build(image_dto)
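
For context, the following is a standalone sketch (not part of this commit) of how the new tiled code path fits together: split the source image into overlapping tiles, upscale each tile independently, multiply the tile coordinates by the model's scale factor, and cross-fade the overlaps when pasting the results into the output canvas. The image size, tile size, and the nearest-neighbour stand-in for spandrel_model.run() are illustrative assumptions; the tile helpers are called with the same arguments as in the diff above.

import numpy as np

from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import TBLR, Tile

# Illustrative inputs: a 768x1024 RGB image and a hypothetical 4x model.
scale = 4
min_overlap = 20
tile_size = 512
src = np.random.randint(0, 255, (768, 1024, 3), dtype=np.uint8)

# 1. Split the input into tiles that overlap by at least `min_overlap` pixels.
tiles = calc_tiles_min_overlap(
    image_height=src.shape[0],
    image_width=src.shape[1],
    tile_height=tile_size,
    tile_width=tile_size,
    min_overlap=min_overlap,
)

# 2. "Upscale" each tile. A nearest-neighbour repeat stands in for spandrel_model.run().
tile_images = []
for t in tiles:
    crop = src[t.coords.top : t.coords.bottom, t.coords.left : t.coords.right]
    tile_images.append(np.repeat(np.repeat(crop, scale, axis=0), scale, axis=1))

# 3. Scale each tile's coordinates and overlaps so they index into the larger output canvas
#    (this is what _scale_tile() does in the node).
scaled_tiles = [
    Tile(
        coords=TBLR(
            top=t.coords.top * scale,
            bottom=t.coords.bottom * scale,
            left=t.coords.left * scale,
            right=t.coords.right * scale,
        ),
        overlap=TBLR(
            top=t.overlap.top * scale,
            bottom=t.overlap.bottom * scale,
            left=t.overlap.left * scale,
            right=t.overlap.right * scale,
        ),
    )
    for t in tiles
]

# 4. Blend the upscaled tiles back together; overlapping regions are linearly cross-faded.
out = np.zeros((src.shape[0] * scale, src.shape[1] * scale, 3), dtype=np.uint8)
merge_tiles_with_linear_blending(
    dst_image=out, tiles=scaled_tiles, tile_images=tile_images, blend_amount=min_overlap // 2
)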

invokeai/backend/spandrel_image_to_image_model.py

Lines changed: 5 additions & 0 deletions
@@ -126,6 +126,11 @@ def dtype(self) -> torch.dtype:
         """The dtype of the underlying model."""
         return self._spandrel_model.dtype
 
+    @property
+    def scale(self) -> int:
+        """The scale of the model (e.g. 1x, 2x, 4x, etc.)."""
+        return self._spandrel_model.scale
+
     def calc_size(self) -> int:
         """Get size of the model in memory in bytes."""
         # HACK(ryand): Fix this issue with circular imports.