Skip to content

Commit 710dc6b

Browse files
authored
Merge branch 'main' into stalker7779/backend_base
2 parents 2ef3b49 + 0583101 commit 710dc6b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+4780
-1820
lines changed

invokeai/app/api/routers/images.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -233,21 +233,14 @@ async def get_image_workflow(
233233
)
234234
async def get_image_full(
235235
image_name: str = Path(description="The name of full-resolution image file to get"),
236-
) -> FileResponse:
236+
) -> Response:
237237
"""Gets a full-resolution image file"""
238238

239239
try:
240240
path = ApiDependencies.invoker.services.images.get_path(image_name)
241-
242-
if not ApiDependencies.invoker.services.images.validate_path(path):
243-
raise HTTPException(status_code=404)
244-
245-
response = FileResponse(
246-
path,
247-
media_type="image/png",
248-
filename=image_name,
249-
content_disposition_type="inline",
250-
)
241+
with open(path, "rb") as f:
242+
content = f.read()
243+
response = Response(content, media_type="image/png")
251244
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
252245
return response
253246
except Exception:
@@ -268,15 +261,14 @@ async def get_image_full(
268261
)
269262
async def get_image_thumbnail(
270263
image_name: str = Path(description="The name of thumbnail image file to get"),
271-
) -> FileResponse:
264+
) -> Response:
272265
"""Gets a thumbnail image file"""
273266

274267
try:
275268
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
276-
if not ApiDependencies.invoker.services.images.validate_path(path):
277-
raise HTTPException(status_code=404)
278-
279-
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
269+
with open(path, "rb") as f:
270+
content = f.read()
271+
response = Response(content, media_type="image/webp")
280272
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
281273
return response
282274
except Exception:

invokeai/app/api_app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def find_port(port: int) -> int:
161161
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
162162
# https://github.com/WaylonWalker
163163
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
164+
s.settimeout(1)
164165
if s.connect_ex(("localhost", port)) == 0:
165166
return find_port(port=port + 1)
166167
else:

invokeai/app/invocations/fields.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
4848
ControlNetModel = "ControlNetModelField"
4949
IPAdapterModel = "IPAdapterModelField"
5050
T2IAdapterModel = "T2IAdapterModelField"
51+
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
5152
# endregion
5253

5354
# region Misc Field Types
@@ -134,6 +135,7 @@ class FieldDescriptions:
134135
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
135136
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
136137
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
138+
spandrel_image_to_image_model = "Image-to-Image model"
137139
lora_weight = "The weight at which the LoRA is applied to each model"
138140
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
139141
raw_prompt = "Raw prompt text (no parsing)"
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import numpy as np
2+
import torch
3+
from PIL import Image
4+
from tqdm import tqdm
5+
6+
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
7+
from invokeai.app.invocations.fields import (
8+
FieldDescriptions,
9+
ImageField,
10+
InputField,
11+
UIType,
12+
WithBoard,
13+
WithMetadata,
14+
)
15+
from invokeai.app.invocations.model import ModelIdentifierField
16+
from invokeai.app.invocations.primitives import ImageOutput
17+
from invokeai.app.services.session_processor.session_processor_common import CanceledException
18+
from invokeai.app.services.shared.invocation_context import InvocationContext
19+
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
20+
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
21+
from invokeai.backend.tiles.utils import TBLR, Tile
22+
23+
24+
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
25+
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
26+
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
27+
28+
image: ImageField = InputField(description="The input image")
29+
image_to_image_model: ModelIdentifierField = InputField(
30+
title="Image-to-Image Model",
31+
description=FieldDescriptions.spandrel_image_to_image_model,
32+
ui_type=UIType.SpandrelImageToImageModel,
33+
)
34+
tile_size: int = InputField(
35+
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
36+
)
37+
38+
def _scale_tile(self, tile: Tile, scale: int) -> Tile:
39+
return Tile(
40+
coords=TBLR(
41+
top=tile.coords.top * scale,
42+
bottom=tile.coords.bottom * scale,
43+
left=tile.coords.left * scale,
44+
right=tile.coords.right * scale,
45+
),
46+
overlap=TBLR(
47+
top=tile.overlap.top * scale,
48+
bottom=tile.overlap.bottom * scale,
49+
left=tile.overlap.left * scale,
50+
right=tile.overlap.right * scale,
51+
),
52+
)
53+
54+
@torch.inference_mode()
55+
def invoke(self, context: InvocationContext) -> ImageOutput:
56+
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
57+
# revisit this.
58+
image = context.images.get_pil(self.image.image_name, mode="RGB")
59+
60+
# Compute the image tiles.
61+
if self.tile_size > 0:
62+
min_overlap = 20
63+
tiles = calc_tiles_min_overlap(
64+
image_height=image.height,
65+
image_width=image.width,
66+
tile_height=self.tile_size,
67+
tile_width=self.tile_size,
68+
min_overlap=min_overlap,
69+
)
70+
else:
71+
# No tiling. Generate a single tile that covers the entire image.
72+
min_overlap = 0
73+
tiles = [
74+
Tile(
75+
coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
76+
overlap=TBLR(top=0, bottom=0, left=0, right=0),
77+
)
78+
]
79+
80+
# Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate
81+
# over tiles left-to-right, top-to-bottom.
82+
tiles = sorted(tiles, key=lambda x: x.coords.left)
83+
tiles = sorted(tiles, key=lambda x: x.coords.top)
84+
85+
# Prepare input image for inference.
86+
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
87+
88+
# Load the model.
89+
spandrel_model_info = context.models.load(self.image_to_image_model)
90+
91+
# Run the model on each tile.
92+
with spandrel_model_info as spandrel_model:
93+
assert isinstance(spandrel_model, SpandrelImageToImageModel)
94+
95+
# Scale the tiles for re-assembling the final image.
96+
scale = spandrel_model.scale
97+
scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
98+
99+
# Prepare the output tensor.
100+
_, channels, height, width = image_tensor.shape
101+
output_tensor = torch.zeros(
102+
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
103+
)
104+
105+
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
106+
107+
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
108+
# Exit early if the invocation has been canceled.
109+
if context.util.is_canceled():
110+
raise CanceledException
111+
112+
# Extract the current tile from the input tensor.
113+
input_tile = image_tensor[
114+
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
115+
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
116+
117+
# Run the model on the tile.
118+
output_tile = spandrel_model.run(input_tile)
119+
120+
# Convert the output tile into the output tensor's format.
121+
# (N, C, H, W) -> (C, H, W)
122+
output_tile = output_tile.squeeze(0)
123+
# (C, H, W) -> (H, W, C)
124+
output_tile = output_tile.permute(1, 2, 0)
125+
output_tile = output_tile.clamp(0, 1)
126+
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
127+
128+
# Merge the output tile into the output tensor.
129+
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
130+
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
131+
# it seems unnecessary, but we may find a need in the future.
132+
top_overlap = scaled_tile.overlap.top // 2
133+
left_overlap = scaled_tile.overlap.left // 2
134+
output_tensor[
135+
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
136+
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
137+
:,
138+
] = output_tile[top_overlap:, left_overlap:, :]
139+
140+
# Convert the output tensor to a PIL image.
141+
np_image = output_tensor.detach().numpy().astype(np.uint8)
142+
pil_image = Image.fromarray(np_image)
143+
image_dto = context.images.save(image=pil_image)
144+
return ImageOutput.build(image_dto)

0 commit comments

Comments
 (0)