import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    InputField,
    UIType,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile


@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""

    image: ImageField = InputField(description="The input image")
    image_to_image_model: ModelIdentifierField = InputField(
        title="Image-to-Image Model",
        description=FieldDescriptions.spandrel_image_to_image_model,
        ui_type=UIType.SpandrelImageToImageModel,
    )
    tile_size: int = InputField(
        default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
    )

    def _scale_tile(self, tile: Tile, scale: int) -> Tile:
        return Tile(
            coords=TBLR(
                top=tile.coords.top * scale,
                bottom=tile.coords.bottom * scale,
                left=tile.coords.left * scale,
                right=tile.coords.right * scale,
            ),
            overlap=TBLR(
                top=tile.overlap.top * scale,
                bottom=tile.overlap.bottom * scale,
                left=tile.overlap.left * scale,
                right=tile.overlap.right * scale,
            ),
        )

    @torch.inference_mode()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want
        # to revisit this.
        image = context.images.get_pil(self.image.image_name, mode="RGB")

        # Compute the image tiles.
        if self.tile_size > 0:
            min_overlap = 20
            tiles = calc_tiles_min_overlap(
                image_height=image.height,
                image_width=image.width,
                tile_height=self.tile_size,
                tile_width=self.tile_size,
                min_overlap=min_overlap,
            )
        else:
            # No tiling. Generate a single tile that covers the entire image.
            min_overlap = 0
            tiles = [
                Tile(
                    coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
                    overlap=TBLR(top=0, bottom=0, left=0, right=0),
                )
            ]

        # Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate
        # over tiles left-to-right, top-to-bottom.
        tiles = sorted(tiles, key=lambda x: x.coords.left)
        tiles = sorted(tiles, key=lambda x: x.coords.top)

        # Prepare input image for inference.
        image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)

        # Load the model.
        spandrel_model_info = context.models.load(self.image_to_image_model)

        # Run the model on each tile.
        with spandrel_model_info as spandrel_model:
            assert isinstance(spandrel_model, SpandrelImageToImageModel)

            # Scale the tiles for re-assembling the final image.
            scale = spandrel_model.scale
            scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]

            # Prepare the output tensor.
            _, channels, height, width = image_tensor.shape
            output_tensor = torch.zeros(
                (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
            )

            image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)

            for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
                # Exit early if the invocation has been canceled.
                if context.util.is_canceled():
                    raise CanceledException

                # Extract the current tile from the input tensor.
                input_tile = image_tensor[
                    :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
                ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)

                # Run the model on the tile.
                output_tile = spandrel_model.run(input_tile)

                # Convert the output tile into the output tensor's format.
                # (N, C, H, W) -> (C, H, W)
                output_tile = output_tile.squeeze(0)
                # (C, H, W) -> (H, W, C)
                output_tile = output_tile.permute(1, 2, 0)
                output_tile = output_tile.clamp(0, 1)
                output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))

                # Merge the output tile into the output tensor.
                # We only keep half of the overlap on the top and left side of the tile. We do this in case there are
                # edge artifacts. We don't bother with any 'blending' in the current implementation - for most
                # upscalers it seems unnecessary, but we may find a need in the future.
                top_overlap = scaled_tile.overlap.top // 2
                left_overlap = scaled_tile.overlap.left // 2
                output_tensor[
                    scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
                    scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
                    :,
                ] = output_tile[top_overlap:, left_overlap:, :]

        # Convert the output tensor to a PIL image.
        np_image = output_tensor.detach().numpy().astype(np.uint8)
        pil_image = Image.fromarray(np_image)
        image_dto = context.images.save(image=pil_image)
        return ImageOutput.build(image_dto)
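Not part of the diff: a minimal standalone sketch of the half-overlap tile merge used in the loop above. A 2x nearest-neighbor resize (`fake_upscaler`) stands in for a real spandrel model, and the two hard-coded tiles are invented for illustration; the overlap tuples mimic what `calc_tiles_min_overlap` would report for adjacent tiles.

```python
import torch
import torch.nn.functional as F

scale = 2
image = torch.rand(1, 3, 64, 96)  # (N, C, H, W)


def fake_upscaler(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for spandrel_model.run(): a plain 2x nearest-neighbor upscale.
    return F.interpolate(x, scale_factor=scale, mode="nearest")


# Two 64x64 tiles covering a 64x96 image with a 32px horizontal overlap.
# Each entry: ((top, bottom, left, right), (ov_top, ov_bottom, ov_left, ov_right)).
tiles = [
    ((0, 64, 0, 64), (0, 0, 0, 32)),
    ((0, 64, 32, 96), (0, 0, 32, 0)),
]

output = torch.zeros(64 * scale, 96 * scale, 3)
for (top, bottom, left, right), (ov_top, _, ov_left, _) in tiles:
    out_tile = fake_upscaler(image[:, :, top:bottom, left:right])
    out_tile = out_tile.squeeze(0).permute(1, 2, 0)  # (N, C, H, W) -> (H, W, C)
    # Keep only half of the top/left overlap, as in the invocation's merge step.
    top_keep = (ov_top * scale) // 2
    left_keep = (ov_left * scale) // 2
    output[
        top * scale + top_keep : bottom * scale,
        left * scale + left_keep : right * scale,
        :,
    ] = out_tile[top_keep:, left_keep:, :]

# Nearest-neighbor upscaling is purely local, so the tiled result matches the
# untiled result exactly; with a real model the seams would differ slightly,
# which is why the overlap exists.
expected = fake_upscaler(image).squeeze(0).permute(1, 2, 0)
assert torch.equal(output, expected)
```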