|
1 |
| -import numpy as np |
2 | 1 | import torch
|
3 |
| -from PIL import Image |
4 | 2 | from tqdm import tqdm
|
5 | 3 |
|
6 | 4 | from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
|
16 | 14 | from invokeai.app.invocations.primitives import ImageOutput
|
17 | 15 | from invokeai.app.services.shared.invocation_context import InvocationContext
|
18 | 16 | from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
19 |
| -from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending |
| 17 | +from invokeai.backend.tiles.tiles import calc_tiles_min_overlap |
20 | 18 | from invokeai.backend.tiles.utils import TBLR, Tile
|
21 | 19 |
|
22 | 20 |
|
@@ -50,6 +48,29 @@ def _scale_tile(self, tile: Tile, scale: int) -> Tile:
|
50 | 48 | ),
|
51 | 49 | )
|
52 | 50 |
|
| 51 | + def _merge_tiles(self, tiles: list[Tile], tile_tensors: list[torch.Tensor], out_tensor: torch.Tensor): |
| 52 | + """A simple tile merging algorithm. tile_tensors are merged into out_tensor. When adjacent tiles overlap, we |
| 53 | + split the overlap in half. No 'blending' is applied. |
| 54 | + """ |
| 55 | + # Sort tiles and tensors by left x coordinate, then (stably) by top y coordinate, so top is the primary sort key. |
| 56 | + # During tile processing, we want to iterate over tiles left-to-right within each row, with rows going top-to-bottom. |
| 57 | + tiles_and_tensors = list(zip(tiles, tile_tensors, strict=True)) |
| 58 | + tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.left) |
| 59 | + tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.top) |
| 60 | + |
| 61 | + for tile, tile_tensor in tiles_and_tensors: |
| 62 | + # We only keep half of the overlap on the top and left side of the tile. We do this in case there are edge |
| 63 | + # artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers it seems |
| 64 | + # unnecessary, but we may find a need in the future. |
| 65 | + top_overlap = tile.overlap.top // 2 |
| 66 | + left_overlap = tile.overlap.left // 2 |
| 67 | + out_tensor[ |
| 68 | + :, |
| 69 | + :, |
| 70 | + tile.coords.top + top_overlap : tile.coords.bottom, |
| 71 | + tile.coords.left + left_overlap : tile.coords.right, |
| 72 | + ] = tile_tensor[:, :, top_overlap:, left_overlap:] |
| 73 | + |
53 | 74 | @torch.inference_mode()
|
54 | 75 | def invoke(self, context: InvocationContext) -> ImageOutput:
|
55 | 76 | # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
@@ -100,15 +121,19 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
|
100 | 121 | )
|
101 | 122 | output_tiles.append(output_tile)
|
102 | 123 |
|
103 |
| - # Merge tiles into output image. |
104 |
| - np_output_tiles = [np.array(SpandrelImageToImageModel.tensor_to_pil(tile)) for tile in output_tiles] |
105 |
| - _, channels, height, width = image_tensor.shape |
106 |
| - np_out_image = np.zeros((height * scale, width * scale, channels), dtype=np.uint8) |
107 |
| - merge_tiles_with_linear_blending( |
108 |
| - dst_image=np_out_image, tiles=scaled_tiles, tile_images=np_output_tiles, blend_amount=min_overlap // 2 |
| 124 | + # TODO(ryand): There are opportunities to reduce peak VRAM utilization here if it becomes an issue: |
| 125 | + # - Keep the input tensor on the CPU. |
| 126 | + # - Move each tile to the GPU as it is processed. |
| 127 | + # - Move output tensors back to the CPU as they are produced, and merge them into the output tensor. |
| 128 | + |
| 129 | + # Merge the tiles to an output tensor. |
| 130 | + batch_size, channels, height, width = image_tensor.shape |
| 131 | + output_tensor = torch.zeros( |
| 132 | + (batch_size, channels, height * scale, width * scale), dtype=image_tensor.dtype, device=image_tensor.device |
109 | 133 | )
|
| 134 | + self._merge_tiles(scaled_tiles, output_tiles, output_tensor) |
110 | 135 |
|
111 | 136 | # Convert the output tensor to a PIL image.
|
112 |
| - pil_image = Image.fromarray(np_out_image) |
| 137 | + pil_image = SpandrelImageToImageModel.tensor_to_pil(output_tensor) |
113 | 138 | image_dto = context.images.save(image=pil_image)
|
114 | 139 | return ImageOutput.build(image_dto)
|
0 commit comments