Skip to content

Commit ecbff2a

Browse files
committed
Whoops... forgot to commit this file.
1 parent 0ce6ec6 commit ecbff2a

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import torch
2+
3+
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
4+
from invokeai.app.invocations.fields import (
5+
FieldDescriptions,
6+
ImageField,
7+
InputField,
8+
UIType,
9+
WithBoard,
10+
WithMetadata,
11+
)
12+
from invokeai.app.invocations.model import ModelIdentifierField
13+
from invokeai.app.invocations.primitives import ImageOutput
14+
from invokeai.app.services.shared.invocation_context import InvocationContext
15+
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
16+
17+
18+
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.0.0")
19+
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
20+
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
21+
22+
image: ImageField = InputField(description="The input image")
23+
image_to_image_model: ModelIdentifierField = InputField(
24+
title="Image-to-Image Model",
25+
description=FieldDescriptions.spandrel_image_to_image_model,
26+
ui_type=UIType.SpandrelImageToImageModel,
27+
)
28+
29+
@torch.inference_mode()
30+
def invoke(self, context: InvocationContext) -> ImageOutput:
31+
image = context.images.get_pil(self.image.image_name)
32+
33+
# Load the model.
34+
spandrel_model_info = context.models.load(self.image_to_image_model)
35+
36+
with spandrel_model_info as spandrel_model:
37+
assert isinstance(spandrel_model, SpandrelImageToImageModel)
38+
39+
# Prepare input image for inference.
40+
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
41+
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
42+
43+
# Run inference.
44+
image_tensor = spandrel_model.run(image_tensor)
45+
46+
# Convert the output tensor to a PIL image.
47+
pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor)
48+
image_dto = context.images.save(image=pil_image)
49+
return ImageOutput.build(image_dto)

0 commit comments

Comments
 (0)