Skip to content

Commit 29c8ddf

Browse files
committed
WIP - A bunch of boilerplate to support Spandrel Image-to-Image models throughout the model manager and the frontend.
1 parent 95079dc commit 29c8ddf

File tree

15 files changed

+287
-19
lines changed

15 files changed

+287
-19
lines changed

invokeai/app/invocations/fields.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
4848
ControlNetModel = "ControlNetModelField"
4949
IPAdapterModel = "IPAdapterModelField"
5050
T2IAdapterModel = "T2IAdapterModelField"
51+
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
5152
# endregion
5253

5354
# region Misc Field Types
@@ -134,6 +135,7 @@ class FieldDescriptions:
134135
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
135136
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
136137
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
138+
spandrel_image_to_image_model = "Spandrel Image-to-Image model"
137139
lora_weight = "The weight at which the LoRA is applied to each model"
138140
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
139141
raw_prompt = "Raw prompt text (no parsing)"

invokeai/backend/model_manager/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,17 @@ def get_tag() -> Tag:
373373
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
374374

375375

376+
class SpandrelImageToImageConfig(ModelConfigBase):
377+
"""Model config for Spandrel Image to Image models."""
378+
379+
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
380+
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
381+
382+
@staticmethod
383+
def get_tag() -> Tag:
384+
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
385+
386+
376387
def get_model_discriminator_value(v: Any) -> str:
377388
"""
378389
Computes the discriminator value for a model config.
@@ -409,6 +420,7 @@ def get_model_discriminator_value(v: Any) -> str:
409420
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
410421
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
411422
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
423+
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
412424
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
413425
],
414426
Discriminator(get_model_discriminator_value),

invokeai/backend/model_manager/probe.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,14 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
243243

244244
# Check if the model can be loaded as a SpandrelImageToImageModel.
245245
try:
246-
_ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
246+
# TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected.
247+
# _ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
248+
_ = SpandrelImageToImageModel.load_from_file(model_path)
247249
return ModelType.SpandrelImageToImage
248-
except Exception:
250+
except Exception as e:
249251
# TODO(ryand): Catch a more specific exception type here if we can.
252+
# TODO(ryand): Delete this print statement.
253+
print(e)
250254
pass
251255

252256
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
@@ -579,9 +583,9 @@ def get_base_type(self) -> BaseModelType:
579583
raise NotImplementedError()
580584

581585

582-
class SpandrelImageToImageModelProbe(CheckpointProbeBase):
586+
class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
583587
def get_base_type(self) -> BaseModelType:
584-
raise NotImplementedError()
588+
return BaseModelType.Any
585589

586590

587591
########################################################
@@ -791,6 +795,11 @@ def get_base_type(self) -> BaseModelType:
791795
return BaseModelType.Any
792796

793797

798+
class SpandrelImageToImageFolderProbe(FolderProbeBase):
799+
def get_base_type(self) -> BaseModelType:
800+
raise NotImplementedError()
801+
802+
794803
class T2IAdapterFolderProbe(FolderProbeBase):
795804
def get_base_type(self) -> BaseModelType:
796805
config_file = self.model_path / "config.json"
@@ -820,6 +829,7 @@ def get_base_type(self) -> BaseModelType:
820829
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
821830
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
822831
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
832+
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
823833

824834
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
825835
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -829,5 +839,6 @@ def get_base_type(self) -> BaseModelType:
829839
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
830840
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
831841
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
842+
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
832843

833844
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ import {
3232
isSDXLMainModelFieldInputTemplate,
3333
isSDXLRefinerModelFieldInputInstance,
3434
isSDXLRefinerModelFieldInputTemplate,
35+
isSpandrelImageToImageModelFieldInputInstance,
36+
isSpandrelImageToImageModelFieldInputTemplate,
3537
isStringFieldInputInstance,
3638
isStringFieldInputTemplate,
3739
isT2IAdapterModelFieldInputInstance,
@@ -54,6 +56,7 @@ import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
5456
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
5557
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
5658
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
59+
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
5760
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
5861
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
5962
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
@@ -125,6 +128,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
125128
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
126129
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
127130
}
131+
132+
if (isSpandrelImageToImageModelFieldInputInstance(fieldInstance) && isSpandrelImageToImageModelFieldInputTemplate(fieldTemplate)) {
133+
return <SpandrelImageToImageModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
134+
}
135+
128136
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
129137
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
130138
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
2+
import { useAppDispatch } from 'app/store/storeHooks';
3+
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
4+
import { fieldSpandrelImageToImageModelValueChanged, } from 'features/nodes/store/nodesSlice';
5+
import type {
6+
SpandrelImageToImageModelFieldInputInstance,
7+
SpandrelImageToImageModelFieldInputTemplate,
8+
} from 'features/nodes/types/field';
9+
import { memo, useCallback } from 'react';
10+
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
11+
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
12+
13+
import type { FieldComponentProps } from './types';
14+
15+
const SpandrelImageToImageModelFieldInputComponent = (
16+
props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
17+
) => {
18+
const { nodeId, field } = props;
19+
const dispatch = useAppDispatch();
20+
21+
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
22+
23+
const _onChange = useCallback(
24+
(value: SpandrelImageToImageModelConfig | null) => {
25+
if (!value) {
26+
return;
27+
}
28+
dispatch(
29+
30+
fieldSpandrelImageToImageModelValueChanged({
31+
nodeId,
32+
fieldName: field.name,
33+
value,
34+
})
35+
);
36+
},
37+
[dispatch, field.name, nodeId]
38+
);
39+
40+
const { options, value, onChange } = useGroupedModelCombobox({
41+
modelConfigs,
42+
onChange: _onChange,
43+
selectedModel: field.value,
44+
isLoading,
45+
});
46+
47+
return (
48+
<Tooltip label={value?.description}>
49+
<FormControl className="nowheel nodrag" isInvalid={!value}>
50+
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
51+
</FormControl>
52+
</Tooltip>
53+
);
54+
};
55+
56+
export default memo(SpandrelImageToImageModelFieldInputComponent);

invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import type {
1919
ModelIdentifierFieldValue,
2020
SchedulerFieldValue,
2121
SDXLRefinerModelFieldValue,
22+
SpandrelImageToImageModelFieldValue,
2223
StatefulFieldValue,
2324
StringFieldValue,
2425
T2IAdapterModelFieldValue,
@@ -39,6 +40,7 @@ import {
3940
zModelIdentifierFieldValue,
4041
zSchedulerFieldValue,
4142
zSDXLRefinerModelFieldValue,
43+
zSpandrelImageToImageModelFieldValue,
4244
zStatefulFieldValue,
4345
zStringFieldValue,
4446
zT2IAdapterModelFieldValue,
@@ -333,6 +335,9 @@ export const nodesSlice = createSlice({
333335
fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
334336
fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
335337
},
338+
fieldSpandrelImageToImageModelValueChanged: (state, action: FieldValueAction<SpandrelImageToImageModelFieldValue>) => {
339+
fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
340+
},
336341
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
337342
fieldValueReducer(state, action, zEnumFieldValue);
338343
},
@@ -384,6 +389,7 @@ export const {
384389
fieldImageValueChanged,
385390
fieldIPAdapterModelValueChanged,
386391
fieldT2IAdapterModelValueChanged,
392+
fieldSpandrelImageToImageModelValueChanged,
387393
fieldLabelChanged,
388394
fieldLoRAModelValueChanged,
389395
fieldModelIdentifierValueChanged,

invokeai/frontend/web/src/features/nodes/types/common.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ const zModelType = z.enum([
6666
'embedding',
6767
'onnx',
6868
'clip_vision',
69+
'spandrel_image_to_image',
6970
]);
7071
const zSubModelType = z.enum([
7172
'unet',

invokeai/frontend/web/src/features/nodes/types/constants.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ export const MODEL_TYPES = [
3838
'VAEField',
3939
'CLIPField',
4040
'T2IAdapterModelField',
41+
'SpandrelImageToImageModelField',
4142
];
4243

4344
/**
@@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
6263
MainModelField: 'teal.500',
6364
SDXLMainModelField: 'teal.500',
6465
SDXLRefinerModelField: 'teal.500',
66+
SpandrelImageToImageModelField: 'teal.500',
6567
StringField: 'yellow.500',
6668
T2IAdapterField: 'teal.500',
6769
T2IAdapterModelField: 'teal.500',

invokeai/frontend/web/src/features/nodes/types/field.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
139139
name: z.literal('T2IAdapterModelField'),
140140
originalType: zStatelessFieldType.optional(),
141141
});
142+
const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
143+
name: z.literal('SpandrelImageToImageModelField'),
144+
originalType: zStatelessFieldType.optional(),
145+
});
142146
const zSchedulerFieldType = zFieldTypeBase.extend({
143147
name: z.literal('SchedulerField'),
144148
originalType: zStatelessFieldType.optional(),
@@ -160,6 +164,7 @@ const zStatefulFieldType = z.union([
160164
zControlNetModelFieldType,
161165
zIPAdapterModelFieldType,
162166
zT2IAdapterModelFieldType,
167+
zSpandrelImageToImageModelFieldType,
163168
zColorFieldType,
164169
zSchedulerFieldType,
165170
]);
@@ -581,6 +586,30 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
581586
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
582587
// #endregion
583588

589+
// #region SpandrelModelToModelField
590+
591+
export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional();
592+
const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({
593+
value: zSpandrelImageToImageModelFieldValue,
594+
});
595+
const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({
596+
type: zSpandrelImageToImageModelFieldType,
597+
originalType: zFieldType.optional(),
598+
default: zSpandrelImageToImageModelFieldValue,
599+
});
600+
const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
601+
type: zSpandrelImageToImageModelFieldType,
602+
});
603+
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
604+
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
605+
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
606+
export const isSpandrelImageToImageModelFieldInputInstance = (val: unknown): val is SpandrelImageToImageModelFieldInputInstance =>
607+
zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
608+
export const isSpandrelImageToImageModelFieldInputTemplate = (val: unknown): val is SpandrelImageToImageModelFieldInputTemplate =>
609+
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
610+
// #endregion
611+
612+
584613
// #region SchedulerField
585614

586615
export const zSchedulerFieldValue = zSchedulerField.optional();
@@ -667,6 +696,7 @@ export const zStatefulFieldValue = z.union([
667696
zControlNetModelFieldValue,
668697
zIPAdapterModelFieldValue,
669698
zT2IAdapterModelFieldValue,
699+
zSpandrelImageToImageModelFieldValue,
670700
zColorFieldValue,
671701
zSchedulerFieldValue,
672702
]);
@@ -694,6 +724,7 @@ const zStatefulFieldInputInstance = z.union([
694724
zControlNetModelFieldInputInstance,
695725
zIPAdapterModelFieldInputInstance,
696726
zT2IAdapterModelFieldInputInstance,
727+
zSpandrelImageToImageModelFieldInputInstance,
697728
zColorFieldInputInstance,
698729
zSchedulerFieldInputInstance,
699730
]);
@@ -722,6 +753,7 @@ const zStatefulFieldInputTemplate = z.union([
722753
zControlNetModelFieldInputTemplate,
723754
zIPAdapterModelFieldInputTemplate,
724755
zT2IAdapterModelFieldInputTemplate,
756+
zSpandrelImageToImageModelFieldInputTemplate,
725757
zColorFieldInputTemplate,
726758
zSchedulerFieldInputTemplate,
727759
zStatelessFieldInputTemplate,
@@ -751,6 +783,7 @@ const zStatefulFieldOutputTemplate = z.union([
751783
zControlNetModelFieldOutputTemplate,
752784
zIPAdapterModelFieldOutputTemplate,
753785
zT2IAdapterModelFieldOutputTemplate,
786+
zSpandrelImageToImageModelFieldOutputTemplate,
754787
zColorFieldOutputTemplate,
755788
zSchedulerFieldOutputTemplate,
756789
]);

invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
1818
SDXLRefinerModelField: undefined,
1919
StringField: '',
2020
T2IAdapterModelField: undefined,
21+
SpandrelImageToImageModelField: undefined,
2122
VAEModelField: undefined,
2223
ControlNetModelField: undefined,
2324
};

0 commit comments

Comments
 (0)