Skip to content

Commit d0d2955

Browse files
committed
Reduce peak VRAM utilization of SpandrelImageToImageInvocation.
1 parent d868d5d commit d0d2955

File tree

1 file changed

+43
-43
lines changed

1 file changed

+43
-43
lines changed

invokeai/app/invocations/spandrel_image_to_image.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import numpy as np
12
import torch
3+
from PIL import Image
24
from tqdm import tqdm
35

46
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@@ -48,29 +50,6 @@ def _scale_tile(self, tile: Tile, scale: int) -> Tile:
4850
),
4951
)
5052

51-
def _merge_tiles(self, tiles: list[Tile], tile_tensors: list[torch.Tensor], out_tensor: torch.Tensor):
52-
"""A simple tile merging algorithm. tile_tensors are merged into out_tensor. When adjacent tiles overlap, we
53-
split the overlap in half. No 'blending' is applied.
54-
"""
55-
# Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
56-
# iterate over tiles left-to-right, top-to-bottom.
57-
tiles_and_tensors = list(zip(tiles, tile_tensors, strict=True))
58-
tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.left)
59-
tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.top)
60-
61-
for tile, tile_tensor in tiles_and_tensors:
62-
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are edge
63-
# artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers it seems
64-
# unnecessary, but we may find a need in the future.
65-
top_overlap = tile.overlap.top // 2
66-
left_overlap = tile.overlap.left // 2
67-
out_tensor[
68-
:,
69-
:,
70-
tile.coords.top + top_overlap : tile.coords.bottom,
71-
tile.coords.left + left_overlap : tile.coords.right,
72-
] = tile_tensor[:, :, top_overlap:, left_overlap:]
73-
7453
@torch.inference_mode()
7554
def invoke(self, context: InvocationContext) -> ImageOutput:
7655
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
@@ -97,43 +76,64 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
9776
)
9877
]
9978

79+
# Sort tiles by top y coordinate, breaking ties by left x coordinate. This is done with two stable sorts: first by
80+
# left, then by top. During tile processing, we want to iterate over tiles left-to-right, top-to-bottom.
81+
tiles = sorted(tiles, key=lambda x: x.coords.left)
82+
tiles = sorted(tiles, key=lambda x: x.coords.top)
83+
10084
# Prepare input image for inference.
10185
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
10286

10387
# Load the model.
10488
spandrel_model_info = context.models.load(self.image_to_image_model)
10589

10690
# Run the model on each tile.
107-
output_tiles: list[torch.Tensor] = []
108-
scale: int = 1
10991
with spandrel_model_info as spandrel_model:
11092
assert isinstance(spandrel_model, SpandrelImageToImageModel)
11193

11294
# Scale the tiles for re-assembling the final image.
11395
scale = spandrel_model.scale
11496
scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
11597

116-
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
117-
118-
for tile in tqdm(tiles, desc="Upscaling Tiles"):
119-
output_tile = spandrel_model.run(
120-
image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
121-
)
122-
output_tiles.append(output_tile)
98+
# Prepare the output tensor.
99+
_, channels, height, width = image_tensor.shape
100+
output_tensor = torch.zeros(
101+
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
102+
)
123103

124-
# TODO(ryand): There are opportunities to reduce peak VRAM utilization here if it becomes an issue:
125-
# - Keep the input tensor on the CPU.
126-
# - Move each tile to the GPU as it is processed.
127-
# - Move output tensors back to the CPU as they are produced, and merge them into the output tensor.
104+
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
128105

129-
# Merge the tiles to an output tensor.
130-
batch_size, channels, height, width = image_tensor.shape
131-
output_tensor = torch.zeros(
132-
(batch_size, channels, height * scale, width * scale), dtype=image_tensor.dtype, device=image_tensor.device
133-
)
134-
self._merge_tiles(scaled_tiles, output_tiles, output_tensor)
106+
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
107+
# Extract the current tile from the input tensor.
108+
input_tile = image_tensor[
109+
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
110+
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
111+
112+
# Run the model on the tile.
113+
output_tile = spandrel_model.run(input_tile)
114+
115+
# Convert the output tile into the output tensor's format.
116+
# (N, C, H, W) -> (C, H, W)
117+
output_tile = output_tile.squeeze(0)
118+
# (C, H, W) -> (H, W, C)
119+
output_tile = output_tile.permute(1, 2, 0)
120+
output_tile = output_tile.clamp(0, 1)
121+
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
122+
123+
# Merge the output tile into the output tensor.
124+
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
125+
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
126+
# it seems unnecessary, but we may find a need in the future.
127+
top_overlap = scaled_tile.overlap.top // 2
128+
left_overlap = scaled_tile.overlap.left // 2
129+
output_tensor[
130+
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
131+
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
132+
:,
133+
] = output_tile[top_overlap:, left_overlap:, :]
135134

136135
# Convert the output tensor to a PIL image.
137-
pil_image = SpandrelImageToImageModel.tensor_to_pil(output_tensor)
136+
np_image = output_tensor.detach().numpy().astype(np.uint8)
137+
pil_image = Image.fromarray(np_image)
138138
image_dto = context.images.save(image=pil_image)
139139
return ImageOutput.build(image_dto)

0 commit comments

Comments
 (0)