Skip to content

Commit 3f82c38

Browse files
feat(ui): allow changing model type in MM, fix up base and variant selects
1 parent e348105 commit 3f82c38

File tree

6 files changed

+78
-14
lines changed

6 files changed

+78
-14
lines changed

invokeai/frontend/web/src/features/modelManagerV2/models.ts

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import type { BaseModelType } from 'features/nodes/types/common';
1+
import type { BaseModelType, ModelType, ModelVariantType } from 'features/nodes/types/common';
22
import type { AnyModelConfig } from 'services/api/types';
33
import {
44
isCLIPEmbedModelConfig,
@@ -151,6 +151,30 @@ export const MODEL_BASE_TO_COLOR: Record<BaseModelType, string> = {
151151
unknown: 'red',
152152
};
153153

154+
/**
155+
* Mapping of model type to human readable name
156+
*/
157+
export const MODEL_TYPE_TO_LONG_NAME: Record<ModelType, string> = {
158+
main: 'Main',
159+
vae: 'VAE',
160+
lora: 'LoRA',
161+
llava_onevision: 'LLaVA OneVision',
162+
control_lora: 'ControlLoRA',
163+
controlnet: 'ControlNet',
164+
t2i_adapter: 'T2I Adapter',
165+
ip_adapter: 'IP Adapter',
166+
embedding: 'Embedding',
167+
onnx: 'ONNX',
168+
clip_vision: 'CLIP Vision',
169+
spandrel_image_to_image: 'Spandrel (Image to Image)',
170+
t5_encoder: 'T5 Encoder',
171+
clip_embed: 'CLIP Embed',
172+
siglip: 'SigLIP',
173+
flux_redux: 'FLUX Redux',
174+
video: 'Video',
175+
unknown: 'Unknown',
176+
};
177+
154178
/**
155179
* Mapping of model base to human readable name
156180
*/
@@ -195,6 +219,12 @@ export const MODEL_BASE_TO_SHORT_NAME: Record<BaseModelType, string> = {
195219
unknown: 'Unknown',
196220
};
197221

222+
export const MODEL_VARIANT_TO_LONG_NAME: Record<ModelVariantType, string> = {
223+
normal: 'Normal',
224+
inpaint: 'Inpaint',
225+
depth: 'Depth',
226+
};
227+
198228
/**
199229
* List of base models that make API requests
200230
*/

invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,12 @@ import { useCallback, useMemo } from 'react';
66
import type { Control } from 'react-hook-form';
77
import { useController } from 'react-hook-form';
88
import type { UpdateModelArg } from 'services/api/endpoints/models';
9+
import { objectEntries } from 'tsafe';
910

10-
const options: ComboboxOption[] = [
11-
{ value: 'sd-1', label: MODEL_BASE_TO_LONG_NAME['sd-1'] },
12-
{ value: 'sd-2', label: MODEL_BASE_TO_LONG_NAME['sd-2'] },
13-
{ value: 'sd-3', label: MODEL_BASE_TO_LONG_NAME['sd-3'] },
14-
{ value: 'flux', label: MODEL_BASE_TO_LONG_NAME['flux'] },
15-
{ value: 'sdxl', label: MODEL_BASE_TO_LONG_NAME['sdxl'] },
16-
{ value: 'sdxl-refiner', label: MODEL_BASE_TO_LONG_NAME['sdxl-refiner'] },
17-
];
11+
const options: ComboboxOption[] = objectEntries(MODEL_BASE_TO_LONG_NAME).map(([value, label]) => ({
12+
label,
13+
value,
14+
}));
1815

1916
type Props = {
2017
control: Control<UpdateModelArg['body']>;
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
2+
import { Combobox } from '@invoke-ai/ui-library';
3+
import { typedMemo } from 'common/util/typedMemo';
4+
import { MODEL_TYPE_TO_LONG_NAME } from 'features/modelManagerV2/models';
5+
import { useCallback, useMemo } from 'react';
6+
import type { Control } from 'react-hook-form';
7+
import { useController } from 'react-hook-form';
8+
import type { UpdateModelArg } from 'services/api/endpoints/models';
9+
import { objectEntries } from 'tsafe';
10+
11+
const options: ComboboxOption[] = objectEntries(MODEL_TYPE_TO_LONG_NAME).map(([value, label]) => ({
12+
label,
13+
value,
14+
}));
15+
16+
type Props = {
17+
control: Control<UpdateModelArg['body']>;
18+
};
19+
20+
const ModelTypeSelect = ({ control }: Props) => {
21+
const { field } = useController({ control, name: 'type' });
22+
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
23+
const onChange = useCallback<ComboboxOnChange>(
24+
(v) => {
25+
field.onChange(v?.value);
26+
},
27+
[field]
28+
);
29+
return <Combobox value={value} options={options} onChange={onChange} />;
30+
};
31+
32+
export default typedMemo(ModelTypeSelect);

invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
22
import { Combobox } from '@invoke-ai/ui-library';
33
import { typedMemo } from 'common/util/typedMemo';
4+
import { MODEL_VARIANT_TO_LONG_NAME } from 'features/modelManagerV2/models';
45
import { useCallback, useMemo } from 'react';
56
import type { Control } from 'react-hook-form';
67
import { useController } from 'react-hook-form';
78
import type { UpdateModelArg } from 'services/api/endpoints/models';
9+
import { objectEntries } from 'tsafe';
810

9-
const options: ComboboxOption[] = [
10-
{ value: 'normal', label: 'Normal' },
11-
{ value: 'inpaint', label: 'Inpaint' },
12-
{ value: 'depth', label: 'Depth' },
13-
];
11+
const options: ComboboxOption[] = objectEntries(MODEL_VARIANT_TO_LONG_NAME).map(([value, label]) => ({ label, value }));
1412

1513
type Props = {
1614
control: Control<UpdateModelArg['body']>;

invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import { type UpdateModelArg, useUpdateModelMutation } from 'services/api/endpoi
2222
import type { AnyModelConfig } from 'services/api/types';
2323

2424
import BaseModelSelect from './Fields/BaseModelSelect';
25+
import ModelTypeSelect from './Fields/ModelTypeSelect';
2526
import ModelVariantSelect from './Fields/ModelVariantSelect';
2627
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
2728
import { ModelFooter } from './ModelFooter';
@@ -127,6 +128,10 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
127128
</Heading>
128129
)}
129130
<SimpleGrid columns={2} gap={4}>
131+
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
132+
<FormLabel>{t('modelManager.modelType')}</FormLabel>
133+
<ModelTypeSelect control={form.control} />
134+
</FormControl>
130135
{modelConfig.type !== 'clip_vision' && (
131136
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
132137
<FormLabel>{t('modelManager.baseModel')}</FormLabel>

invokeai/frontend/web/src/features/nodes/types/common.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ export const zModelType = z.enum([
129129
'video',
130130
'unknown',
131131
]);
132+
export type ModelType = z.infer<typeof zModelType>;
132133
const zSubModelType = z.enum([
133134
'unet',
134135
'transformer',
@@ -148,6 +149,7 @@ export type SubModelType = z.infer<typeof zSubModelType>;
148149

149150
export const zClipVariantType = z.enum(['large', 'gigantic']);
150151
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
152+
export type ModelVariantType = z.infer<typeof zModelVariantType>;
151153
export const zModelFormat = z.enum([
152154
'omi',
153155
'diffusers',

0 commit comments

Comments
 (0)