
Commit 7ad32dc

Add support for Spandrel Image-to-Image models (e.g. ESRGAN, Real-ESRGAN, Swin-IR, DAT, etc.) (#6556)
## Summary

- Add support for all [spandrel](https://github.com/chaiNNer-org/spandrel) image-to-image models - this is a collection of many popular super-resolution models (e.g. ESRGAN, Real-ESRGAN, SwinIR, DAT, etc.)

Examples of supported models:

- DAT: https://drive.google.com/drive/folders/1iBdf_-LVZuz_PAbFtuxSKd_11RL1YKxM
- SwinIR: https://github.com/JingyunLiang/SwinIR/releases
- Any ESRGAN / Real-ESRGAN model

## Related Issues

Closes #6394

## QA Instructions

- [x] Test that unsupported models still fail the probe (i.e. no false positive spandrel models)
- [x] Test adding a few non-spandrel model types
- [x] Test adding a handful of spandrel model types: ESRGAN, Real-ESRGAN, SwinIR, DAT
- [x] Verify model size estimation for the model cache
- [x] Test using the spandrel models in a practical image upscaling workflow

## Merge Plan

- [x] Get approval from @brandonrising and @maryhipp before merging - this PR has commercial implications.
- [x] Merge #6571 and change the target branch to `main`

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2 parents 7905a46 + 81991e0 commit 7ad32dc

File tree

23 files changed (+719 / -157 lines)


invokeai/app/invocations/fields.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
     ControlNetModel = "ControlNetModelField"
     IPAdapterModel = "IPAdapterModelField"
     T2IAdapterModel = "T2IAdapterModelField"
+    SpandrelImageToImageModel = "SpandrelImageToImageModelField"
     # endregion

     # region Misc Field Types
@@ -134,6 +135,7 @@ class FieldDescriptions:
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
+    spandrel_image_to_image_model = "Image-to-Image model"
     lora_weight = "The weight at which the LoRA is applied to each model"
     compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
     raw_prompt = "Raw prompt text (no parsing)"
```
Lines changed: 49 additions & 0 deletions
```diff
@@ -0,0 +1,49 @@
+import torch
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    ImageField,
+    InputField,
+    UIType,
+    WithBoard,
+    WithMetadata,
+)
+from invokeai.app.invocations.model import ModelIdentifierField
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
+
+
+@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.0.0")
+class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
+
+    image: ImageField = InputField(description="The input image")
+    image_to_image_model: ModelIdentifierField = InputField(
+        title="Image-to-Image Model",
+        description=FieldDescriptions.spandrel_image_to_image_model,
+        ui_type=UIType.SpandrelImageToImageModel,
+    )
+
+    @torch.inference_mode()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name)
+
+        # Load the model.
+        spandrel_model_info = context.models.load(self.image_to_image_model)
+
+        with spandrel_model_info as spandrel_model:
+            assert isinstance(spandrel_model, SpandrelImageToImageModel)
+
+            # Prepare input image for inference.
+            image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
+            image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
+
+            # Run inference.
+            image_tensor = spandrel_model.run(image_tensor)
+
+        # Convert the output tensor to a PIL image.
+        pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor)
+        image_dto = context.images.save(image=pil_image)
+        return ImageOutput.build(image_dto)
```
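For reference, the spandrel API that this invocation builds on can be exercised directly. A minimal sketch using only spandrel's public `ModelLoader` / `ImageModelDescriptor` interface; the model path and tensor shape are placeholders:

```python
# Minimal sketch of driving spandrel directly; "model.pth" is a placeholder
# path to any supported checkpoint (ESRGAN, Real-ESRGAN, SwinIR, DAT, ...).
import torch
from spandrel import ImageModelDescriptor, ModelLoader

descriptor = ModelLoader().load_from_file("model.pth")
# Reject anything that is not a plain image-to-image model.
assert isinstance(descriptor, ImageModelDescriptor)

# Spandrel image models consume and produce (N, C, H, W) float tensors in [0, 1].
image = torch.rand(1, 3, 64, 64)
with torch.inference_mode():
    output = descriptor(image)

print(descriptor.architecture.name, descriptor.scale, output.shape)
```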

invokeai/backend/model_manager/config.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -67,6 +67,7 @@ class ModelType(str, Enum):
     IPAdapter = "ip_adapter"
     CLIPVision = "clip_vision"
     T2IAdapter = "t2i_adapter"
+    SpandrelImageToImage = "spandrel_image_to_image"


 class SubModelType(str, Enum):
@@ -371,6 +372,17 @@ def get_tag() -> Tag:
         return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")


+class SpandrelImageToImageConfig(ModelConfigBase):
+    """Model config for Spandrel Image to Image models."""
+
+    type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
+    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
+
+
 def get_model_discriminator_value(v: Any) -> str:
     """
     Computes the discriminator value for a model config.
@@ -407,6 +419,7 @@ def get_model_discriminator_value(v: Any) -> str:
         Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
         Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
         Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
+        Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
         Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
     ],
     Discriminator(get_model_discriminator_value),
```
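The union above uses pydantic's callable `Discriminator` with per-member `Tag`s: `get_model_discriminator_value` computes a `"{type}.{format}"` key for incoming data, and pydantic dispatches to whichever `Annotated` member carries the matching tag. A self-contained sketch of the same pattern, with hypothetical stand-in config classes rather than InvokeAI's real ones:

```python
# Sketch of the tagged-union dispatch used by the config registry. The two
# config classes are hypothetical stand-ins for illustration only.
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class CheckpointConfigSketch(BaseModel):
    type: Literal["spandrel_image_to_image"] = "spandrel_image_to_image"
    format: Literal["checkpoint"] = "checkpoint"
    path: str


class DiffusersConfigSketch(BaseModel):
    type: Literal["main"] = "main"
    format: Literal["diffusers"] = "diffusers"
    path: str


def discriminator_value(v: Any) -> str:
    # Like get_model_discriminator_value: join type and format into one key.
    if isinstance(v, dict):
        return f"{v['type']}.{v['format']}"
    return f"{v.type}.{v.format}"


AnyConfigSketch = Annotated[
    Union[
        Annotated[CheckpointConfigSketch, Tag("spandrel_image_to_image.checkpoint")],
        Annotated[DiffusersConfigSketch, Tag("main.diffusers")],
    ],
    Discriminator(discriminator_value),
]

config = TypeAdapter(AnyConfigSketch).validate_python(
    {"type": "spandrel_image_to_image", "format": "checkpoint", "path": "esrgan_4x.pth"}
)
assert isinstance(config, CheckpointConfigSketch)
```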
Lines changed: 45 additions & 0 deletions
```diff
@@ -0,0 +1,45 @@
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+from invokeai.backend.model_manager.config import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
+
+
+@ModelLoaderRegistry.register(
+    base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint
+)
+class SpandrelImageToImageModelLoader(ModelLoader):
+    """Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor)."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if submodel_type is not None:
+            raise ValueError("Unexpected submodel requested for Spandrel model.")
+
+        model_path = Path(config.path)
+        model = SpandrelImageToImageModel.load_from_file(model_path)
+
+        torch_dtype = self._torch_dtype
+        if not model.supports_dtype(torch_dtype):
+            self._logger.warning(
+                f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} "
+                "model. Falling back to 'float32'."
+            )
+            torch_dtype = torch.float32
+        model.to(dtype=torch_dtype)
+
+        return model
```
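The `supports_dtype` check above is the wrapper's own method; one plausible implementation, assuming it maps onto spandrel's per-model `supports_half` / `supports_bfloat16` capability flags, might look like this (the body is an assumption, not InvokeAI's actual code):

```python
# Hypothetical body for a supports_dtype check, built on spandrel's
# descriptor capability flags. An illustration, not InvokeAI's code.
import torch
from spandrel import ImageModelDescriptor


def supports_dtype(descriptor: ImageModelDescriptor, dtype: torch.dtype) -> bool:
    """Return True if the wrapped spandrel model can run at `dtype`."""
    if dtype == torch.float16:
        return descriptor.supports_half
    if dtype == torch.bfloat16:
        return descriptor.supports_bfloat16
    # Full precision is always supported.
    return dtype == torch.float32
```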

invokeai/backend/model_manager/load/model_util.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -15,6 +15,7 @@
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.textual_inversion import TextualInversionModelRaw


@@ -33,7 +34,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
     elif isinstance(model, CLIPTokenizer):
         # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
         return 0
-    elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
+    elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
        return model.calc_size()
    else:
        # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
```
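For the size estimation to work, `calc_size()` on the wrapper presumably sums the bytes of the underlying module's tensors. A plausible sketch (an assumption for illustration, not a copy of InvokeAI's code):

```python
# Plausible sketch of calc_size() for a torch-backed wrapper: total bytes of
# parameters plus buffers, which is what the model cache needs to budget RAM/VRAM.
import torch


def calc_size(module: torch.nn.Module) -> int:
    """Estimate a module's in-memory footprint in bytes."""
    param_bytes = sum(p.numel() * p.element_size() for p in module.parameters())
    buffer_bytes = sum(b.numel() * b.element_size() for b in module.buffers())
    return param_bytes + buffer_bytes
```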

invokeai/backend/model_manager/probe.py

Lines changed: 46 additions & 10 deletions
```diff
@@ -4,6 +4,7 @@
 from typing import Any, Dict, Literal, Optional, Union

 import safetensors.torch
+import spandrel
 import torch
 from picklescan.scanner import scan_file_path

@@ -25,6 +26,7 @@
     SchedulerPredictionType,
 )
 from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.util.silence_warnings import SilenceWarnings

 CkptType = Dict[str | int, Any]
@@ -220,24 +222,46 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
         ckpt = ckpt.get("state_dict", ckpt)

         for key in [str(k) for k in ckpt.keys()]:
-            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
+            if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
                 return ModelType.Main
-            elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
+            elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
                 return ModelType.VAE
-            elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
+            elif key.startswith(("lora_te_", "lora_unet_")):
                 return ModelType.LoRA
-            elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
+            elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
                 return ModelType.LoRA
-            elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
+            elif key.startswith(("controlnet", "control_model", "input_blocks")):
                 return ModelType.ControlNet
-            elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
+            elif key.startswith(("image_proj.", "ip_adapter.")):
                 return ModelType.IPAdapter
             elif key in {"emb_params", "string_to_param"}:
                 return ModelType.TextualInversion
-            else:
-                # diffusers-ti
-                if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
-                    return ModelType.TextualInversion
+
+        # diffusers-ti
+        if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
+            return ModelType.TextualInversion
+
+        # Check if the model can be loaded as a SpandrelImageToImageModel.
+        # This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
+        try:
+            # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
+            # explored to avoid this:
+            # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
+            #    device. Unfortunately, some Spandrel models perform operations during initialization that are not
+            #    supported on meta tensors.
+            # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
+            #    This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
+            #    maintain it, and the risk of false positive detections is higher.
+            SpandrelImageToImageModel.load_from_file(model_path)
+            return ModelType.SpandrelImageToImage
+        except spandrel.UnsupportedModelError:
+            pass
+        except RuntimeError as e:
+            if "No such file or directory" in str(e):
+                # This error is expected if the model_path does not exist (which is the case in some unit tests).
+                pass
+            else:
+                raise e

         raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")

@@ -569,6 +593,11 @@ def get_base_type(self) -> BaseModelType:
         raise NotImplementedError()


+class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        return BaseModelType.Any
+
+
 ########################################################
 # classes for probing folders
 #######################################################
@@ -776,6 +805,11 @@ def get_base_type(self) -> BaseModelType:
         return BaseModelType.Any


+class SpandrelImageToImageFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        raise NotImplementedError()
+
+
 class T2IAdapterFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
         config_file = self.model_path / "config.json"
@@ -805,6 +839,7 @@ def get_base_type(self) -> BaseModelType:
 ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
+ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)

 ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -814,5 +849,6 @@ def get_base_type(self) -> BaseModelType:
 ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
+ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)

 ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
```

invokeai/backend/raw_model.py

Lines changed: 11 additions & 13 deletions
```diff
@@ -1,23 +1,21 @@
-"""Base class for 'Raw' models.
-
-The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
-and is used for type checking of calls to the model patcher. Its main purpose
-is to avoid a circular import issues when lora.py tries to import BaseModelType
-from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
-from lora.py.
-
-The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
-that adds additional methods and attributes.
-"""
-
 from abc import ABC, abstractmethod
 from typing import Optional

 import torch


 class RawModel(ABC):
-    """Abstract base class for 'Raw' model wrappers."""
+    """Base class for 'Raw' models.
+
+    The RawModel class is the base class of LoRAModelRaw, TextualInversionModelRaw, etc.
+    and is used for type checking of calls to the model patcher. Its main purpose
+    is to avoid a circular import issues when lora.py tries to import BaseModelType
+    from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
+    from lora.py.
+
+    The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
+    that adds additional methods and attributes.
+    """

     @abstractmethod
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
```