Skip to content

Commit b68871a

Browse files
refactor(ui): move model categorisation-ish logic to central location, simplify model manager models list
1 parent 3f3f941 commit b68871a

File tree

29 files changed

+637
-482
lines changed

29 files changed

+637
-482
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class ModelConfigBase(ABC, BaseModel):
137137

138138
@staticmethod
139139
def json_schema_extra(schema: dict[str, Any]) -> None:
140-
schema["required"].extend(["key", "type", "format"])
140+
schema["required"].extend(["key", "base", "type", "format"])
141141

142142
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
143143

@@ -172,7 +172,8 @@ def __init_subclass__(cls, **kwargs):
172172
super().__init_subclass__(**kwargs)
173173
if issubclass(cls, LegacyProbeMixin):
174174
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
175-
elif cls is not UnknownModelConfig:
175+
# Cannot use `elif isinstance(cls, UnknownModelConfig)` because UnknownModelConfig is not defined yet
176+
elif cls.__name__ != "UnknownModelConfig":
176177
ModelConfigBase.USING_CLASSIFY_API.add(cls)
177178

178179
@staticmethod
@@ -274,16 +275,17 @@ def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
274275

275276

276277
class UnknownModelConfig(ModelConfigBase):
278+
base: Literal[BaseModelType.Any] = BaseModelType.Any
277279
type: Literal[ModelType.Unknown] = ModelType.Unknown
278280
format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown
279281

280282
@classmethod
281-
def matches(cls, *args, **kwargs) -> bool:
282-
raise NotImplementedError("UnknownModelConfig cannot match anything")
283+
def matches(cls, mod: ModelOnDisk) -> bool:
284+
return False
283285

284286
@classmethod
285-
def parse(cls, *args, **kwargs) -> dict[str, Any]:
286-
raise NotImplementedError("UnknownModelConfig cannot parse anything")
287+
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
288+
return {}
287289

288290

289291
class CheckpointConfigBase(ABC, BaseModel):

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ import {
1111
selectCanvasSlice,
1212
} from 'features/controlLayers/store/selectors';
1313
import { getEntityIdentifier } from 'features/controlLayers/store/types';
14+
import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/modelManagerV2/models';
1415
import { modelSelected } from 'features/parameters/store/actions';
15-
import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/parameters/types/constants';
1616
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
1717
import { toast } from 'features/toast/toast';
1818
import { t } from 'i18next';

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import type { Logger } from 'roarr';
3737
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
3838
import type { AnyModelConfig } from 'services/api/types';
3939
import {
40-
isCLIPEmbedModelConfig,
40+
isCLIPEmbedModelConfigOrSubmodel,
4141
isControlLayerModelConfig,
4242
isControlNetModelConfig,
4343
isFluxReduxModelConfig,
@@ -48,7 +48,7 @@ import {
4848
isNonRefinerMainModelConfig,
4949
isRefinerMainModelModelConfig,
5050
isSpandrelImageToImageModelConfig,
51-
isT5EncoderModelConfig,
51+
isT5EncoderModelConfigOrSubmodel,
5252
isVideoModelConfig,
5353
} from 'services/api/types';
5454
import type { JsonObject } from 'type-fest';
@@ -418,7 +418,7 @@ const handleTileControlNetModel: ModelHandler = (models, state, dispatch, log) =
418418

419419
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
420420
const selectedT5EncoderModel = state.params.t5EncoderModel;
421-
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));
421+
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfigOrSubmodel(m));
422422

423423
// If the currently selected model is available, we don't need to do anything
424424
if (selectedT5EncoderModel && t5EncoderModels.some((m) => m.key === selectedT5EncoderModel.key)) {
@@ -446,7 +446,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
446446

447447
const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
448448
const selectedCLIPEmbedModel = state.params.clipEmbedModel;
449-
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m));
449+
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfigOrSubmodel(m));
450450

451451
// If the currently selected model is available, we don't need to do anything
452452
if (selectedCLIPEmbedModel && CLIPEmbedModels.some((m) => m.key === selectedCLIPEmbedModel.key)) {

invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import Konva from 'konva';
1717
import { atom, computed } from 'nanostores';
1818
import type { Logger } from 'roarr';
1919
import { serializeError } from 'serialize-error';
20-
import { buildSelectModelConfig } from 'services/api/hooks/modelsByType';
20+
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
2121
import { isControlLayerModelConfig } from 'services/api/types';
2222
import stableHash from 'stable-hash';
2323
import type { Equals } from 'tsafe';
@@ -202,11 +202,19 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
202202
createInitialFilterConfig = (): FilterConfig => {
203203
if (this.parent.type === 'control_layer_adapter' && this.parent.state.controlAdapter.model) {
204204
// If the parent is a control layer adapter, we should check if the model has a default filter and set it if so
205-
const selectModelConfig = buildSelectModelConfig(
206-
this.parent.state.controlAdapter.model.key,
207-
isControlLayerModelConfig
208-
);
209-
const modelConfig = this.manager.stateApi.runSelector(selectModelConfig);
205+
const key = this.parent.state.controlAdapter.model.key;
206+
const modelConfig = this.manager.stateApi.runSelector((state) => {
207+
const { data } = selectModelConfigsQuery(state);
208+
if (!data) {
209+
return null;
210+
}
211+
return (
212+
modelConfigsAdapterSelectors
213+
.selectAll(data)
214+
.filter(isControlLayerModelConfig)
215+
.find((m) => m.key === key) ?? null
216+
);
217+
});
210218
// This always returns a filter
211219
const filter = getFilterForModel(modelConfig) ?? IMAGE_FILTERS.canny_edge_detection;
212220
return filter.buildDefaults();

invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSl
1313
import { selectModel } from 'features/controlLayers/store/paramsSlice';
1414
import { selectBbox } from 'features/controlLayers/store/selectors';
1515
import type { Coordinate, Rect, Tool } from 'features/controlLayers/store/types';
16+
import { API_BASE_MODELS } from 'features/modelManagerV2/models';
1617
import type { ModelIdentifierField } from 'features/nodes/types/common';
17-
import { API_BASE_MODELS } from 'features/parameters/types/constants';
1818
import Konva from 'konva';
1919
import { atom } from 'nanostores';
2020
import type { Logger } from 'roarr';

invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ import {
3535
getScaledBoundingBoxDimensions,
3636
} from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
3737
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
38+
import { API_BASE_MODELS } from 'features/modelManagerV2/models';
3839
import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
39-
import { API_BASE_MODELS } from 'features/parameters/types/constants';
4040
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
4141
import type { IRect } from 'konva/lib/types';
4242
import type { UndoableOptions } from 'redux-undo';

invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ import {
2525
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
2626
import {
2727
API_BASE_MODELS,
28-
CLIP_SKIP_MAP,
2928
SUPPORTS_ASPECT_RATIO_BASE_MODELS,
3029
SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS,
3130
SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS,
3231
SUPPORTS_PIXEL_DIMENSIONS_BASE_MODELS,
3332
SUPPORTS_REF_IMAGES_BASE_MODELS,
3433
SUPPORTS_SEED_BASE_MODELS,
35-
} from 'features/parameters/types/constants';
34+
} from 'features/modelManagerV2/models';
35+
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
3636
import type {
3737
ParameterCanvasCoherenceMode,
3838
ParameterCFGRescaleMultiplier,

invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf
66
import type { GroupStatusMap } from 'common/components/Picker/Picker';
77
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
88
import { selectBase } from 'features/controlLayers/store/paramsSlice';
9+
import { API_BASE_MODELS } from 'features/modelManagerV2/models';
910
import { ModelPicker } from 'features/parameters/components/ModelPicker';
10-
import { API_BASE_MODELS } from 'features/parameters/types/constants';
1111
import { memo, useCallback, useMemo } from 'react';
1212
import { useTranslation } from 'react-i18next';
1313
import { useLoRAModels } from 'services/api/hooks/modelsByType';
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import type { BaseModelType } from 'features/nodes/types/common';
2+
import type { AnyModelConfig } from 'services/api/types';
3+
import {
4+
isCLIPEmbedModelConfig,
5+
isCLIPVisionModelConfig,
6+
isControlLoRAModelConfig,
7+
isControlNetModelConfig,
8+
isFluxReduxModelConfig,
9+
isIPAdapterModelConfig,
10+
isLLaVAModelConfig,
11+
isLoRAModelConfig,
12+
isNonRefinerMainModelConfig,
13+
isRefinerMainModelModelConfig,
14+
isSigLipModelConfig,
15+
isSpandrelImageToImageModelConfig,
16+
isT2IAdapterModelConfig,
17+
isT5EncoderModelConfig,
18+
isTIModelConfig,
19+
isUnknownModelConfig,
20+
isVAEModelConfig,
21+
} from 'services/api/types';
22+
23+
type ModelCategoryData = {
24+
i18nKey: string;
25+
filter: (config: AnyModelConfig) => boolean;
26+
};
27+
28+
export const MODEL_CATEGORIES: Record<string, ModelCategoryData> = {
29+
main: {
30+
i18nKey: 'model_manager.category.main_models',
31+
filter: isNonRefinerMainModelConfig,
32+
},
33+
refiner: {
34+
i18nKey: 'model_manager.category.refiner_models',
35+
filter: isRefinerMainModelModelConfig,
36+
},
37+
lora: {
38+
i18nKey: 'model_manager.category.lora_models',
39+
filter: isLoRAModelConfig,
40+
},
41+
embedding: {
42+
i18nKey: 'model_manager.category.embedding_models',
43+
filter: isTIModelConfig,
44+
},
45+
controlnet: {
46+
i18nKey: 'model_manager.category.controlnet_models',
47+
filter: isControlNetModelConfig,
48+
},
49+
t2i_adapter: {
50+
i18nKey: 'model_manager.category.t2i_adapter_models',
51+
filter: isT2IAdapterModelConfig,
52+
},
53+
t5_encoder: {
54+
i18nKey: 'model_manager.category.t5_encoder_models',
55+
filter: isT5EncoderModelConfig,
56+
},
57+
control_lora: {
58+
i18nKey: 'model_manager.category.control_lora_models',
59+
filter: isControlLoRAModelConfig,
60+
},
61+
clip_embed: {
62+
i18nKey: 'model_manager.category.clip_embed_models',
63+
filter: isCLIPEmbedModelConfig,
64+
},
65+
spandrel: {
66+
i18nKey: 'model_manager.category.spandrel_image_to_image_models',
67+
filter: isSpandrelImageToImageModelConfig,
68+
},
69+
ip_adapter: {
70+
i18nKey: 'model_manager.category.ip_adapter_models',
71+
filter: isIPAdapterModelConfig,
72+
},
73+
vae: {
74+
i18nKey: 'model_manager.category.vae_models',
75+
filter: isVAEModelConfig,
76+
},
77+
clip_vision: {
78+
i18nKey: 'model_manager.category.clip_vision_models',
79+
filter: isCLIPVisionModelConfig,
80+
},
81+
siglip: {
82+
i18nKey: 'model_manager.category.siglip_models',
83+
filter: isSigLipModelConfig,
84+
},
85+
flux_redux: {
86+
i18nKey: 'model_manager.category.flux_redux_models',
87+
filter: isFluxReduxModelConfig,
88+
},
89+
llava_one_vision: {
90+
i18nKey: 'model_manager.category.llava_one_vision_models',
91+
filter: isLLaVAModelConfig,
92+
},
93+
unknown: {
94+
i18nKey: 'model_manager.category.unknown_models',
95+
filter: isUnknownModelConfig,
96+
},
97+
};
98+
99+
/**
100+
* Mapping of model base to its color
101+
*/
102+
export const MODEL_BASE_TO_COLOR: Record<BaseModelType, string> = {
103+
any: 'base',
104+
'sd-1': 'green',
105+
'sd-2': 'teal',
106+
'sd-3': 'purple',
107+
sdxl: 'invokeBlue',
108+
'sdxl-refiner': 'invokeBlue',
109+
flux: 'gold',
110+
cogview4: 'red',
111+
imagen3: 'pink',
112+
imagen4: 'pink',
113+
'chatgpt-4o': 'pink',
114+
'flux-kontext': 'pink',
115+
'gemini-2.5': 'pink',
116+
veo3: 'purple',
117+
runway: 'green',
118+
};
119+
120+
/**
121+
* Mapping of model base to human readable name
122+
*/
123+
export const MODEL_BASE_TO_LONG_NAME: Record<BaseModelType, string> = {
124+
any: 'Any',
125+
'sd-1': 'Stable Diffusion 1.x',
126+
'sd-2': 'Stable Diffusion 2.x',
127+
'sd-3': 'Stable Diffusion 3.x',
128+
sdxl: 'Stable Diffusion XL',
129+
'sdxl-refiner': 'Stable Diffusion XL Refiner',
130+
flux: 'FLUX',
131+
cogview4: 'CogView4',
132+
imagen3: 'Imagen3',
133+
imagen4: 'Imagen4',
134+
'chatgpt-4o': 'ChatGPT 4o',
135+
'flux-kontext': 'Flux Kontext',
136+
'gemini-2.5': 'Gemini 2.5',
137+
veo3: 'Veo3',
138+
runway: 'Runway',
139+
};
140+
141+
/**
142+
* Mapping of model base to short human readable name
143+
*/
144+
export const MODEL_BASE_TO_SHORT_NAME: Record<BaseModelType, string> = {
145+
any: 'Any',
146+
'sd-1': 'SD1.X',
147+
'sd-2': 'SD2.X',
148+
'sd-3': 'SD3.X',
149+
sdxl: 'SDXL',
150+
'sdxl-refiner': 'SDXLR',
151+
flux: 'FLUX',
152+
cogview4: 'CogView4',
153+
imagen3: 'Imagen3',
154+
imagen4: 'Imagen4',
155+
'chatgpt-4o': 'ChatGPT 4o',
156+
'flux-kontext': 'Flux Kontext',
157+
'gemini-2.5': 'Gemini 2.5',
158+
veo3: 'Veo3',
159+
runway: 'Runway',
160+
};
161+
162+
/**
163+
* List of base models that make API requests
164+
*/
165+
export const API_BASE_MODELS: BaseModelType[] = ['imagen3', 'imagen4', 'chatgpt-4o', 'flux-kontext', 'gemini-2.5'];
166+
167+
export const SUPPORTS_SEED_BASE_MODELS: BaseModelType[] = ['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4'];
168+
169+
export const SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS: BaseModelType[] = ['flux', 'sd-3'];
170+
171+
export const SUPPORTS_REF_IMAGES_BASE_MODELS: BaseModelType[] = [
172+
'sd-1',
173+
'sdxl',
174+
'flux',
175+
'flux-kontext',
176+
'chatgpt-4o',
177+
'gemini-2.5',
178+
];
179+
180+
export const SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS: BaseModelType[] = [
181+
'sd-1',
182+
'sd-2',
183+
'sdxl',
184+
'cogview4',
185+
'sd-3',
186+
'imagen3',
187+
'imagen4',
188+
];
189+
190+
export const SUPPORTS_PIXEL_DIMENSIONS_BASE_MODELS: BaseModelType[] = [
191+
'sd-1',
192+
'sd-2',
193+
'sd-3',
194+
'sdxl',
195+
'flux',
196+
'cogview4',
197+
];
198+
199+
export const SUPPORTS_ASPECT_RATIO_BASE_MODELS: BaseModelType[] = [
200+
'sd-1',
201+
'sd-2',
202+
'sd-3',
203+
'sdxl',
204+
'flux',
205+
'cogview4',
206+
'imagen3',
207+
'imagen4',
208+
'flux-kontext',
209+
'chatgpt-4o',
210+
];
211+
212+
export const VIDEO_BASE_MODELS = ['veo3', 'runway'];
213+
214+
export const REQUIRES_STARTING_FRAME_BASE_MODELS = ['runway'];

0 commit comments

Comments
 (0)