Skip to content

Commit 54aa690

Browse files
feat(ui): update invocation parsing to handle new ui_model_[base|type|variant] attrs
1 parent e6d9dac commit 54aa690

File tree

5 files changed

+65
-10
lines changed

5 files changed

+65
-10
lines changed

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import type { FieldComponentProps } from './types';
1212
type Props = FieldComponentProps<ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate>;
1313

1414
const ModelIdentifierFieldInputComponent = (props: Props) => {
15-
const { nodeId, field } = props;
15+
const { nodeId, field, fieldTemplate } = props;
1616
const dispatch = useAppDispatch();
1717
const { data, isLoading } = useGetModelConfigsQuery();
1818
const onChange = useCallback(
@@ -36,8 +36,28 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
3636
return EMPTY_ARRAY;
3737
}
3838

39-
return modelConfigsAdapterSelectors.selectAll(data);
40-
}, [data]);
39+
if (!fieldTemplate.ui_model_base && !fieldTemplate.ui_model_type) {
40+
return modelConfigsAdapterSelectors.selectAll(data);
41+
}
42+
43+
return modelConfigsAdapterSelectors.selectAll(data).filter((config) => {
44+
if (fieldTemplate.ui_model_base && !fieldTemplate.ui_model_base.includes(config.base)) {
45+
return false;
46+
}
47+
if (fieldTemplate.ui_model_type && !fieldTemplate.ui_model_type.includes(config.type)) {
48+
return false;
49+
}
50+
if (
51+
fieldTemplate.ui_model_variant &&
52+
'variant' in config &&
53+
config.variant &&
54+
!fieldTemplate.ui_model_variant.includes(config.variant)
55+
) {
56+
return false;
57+
}
58+
return true;
59+
});
60+
}, [data, fieldTemplate.ui_model_base, fieldTemplate.ui_model_type, fieldTemplate.ui_model_variant]);
4161

4262
return (
4363
<ModelFieldCombobox

invokeai/frontend/web/src/features/nodes/types/common.test-d.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@ import type {
1212
SchedulerField,
1313
SubModelType,
1414
T2IAdapterField,
15+
zClipVariantType,
16+
zModelVariantType,
1517
} from 'features/nodes/types/common';
16-
import type { Invocation, ModelType, S } from 'services/api/types';
18+
import type { Invocation, S } from 'services/api/types';
1719
import type { Equals, Extends } from 'tsafe';
1820
import { assert } from 'tsafe';
1921
import { describe, test } from 'vitest';
22+
import type z from 'zod';
2023

2124
/**
2225
* These types originate from the server and are recreated as zod schemas manually, for use at runtime.
@@ -38,7 +41,8 @@ describe('Common types', () => {
3841
test('ModelIdentifier', () => assert<Equals<ModelIdentifierField, S['ModelIdentifierField']>>());
3942
test('ModelIdentifier', () => assert<Equals<BaseModelType, S['BaseModelType']>>());
4043
test('ModelIdentifier', () => assert<Equals<SubModelType, S['SubModelType']>>());
41-
test('ModelIdentifier', () => assert<Equals<ModelType, S['ModelType']>>());
44+
test('ClipVariantType', () => assert<Equals<z.infer<typeof zClipVariantType>, S['ClipVariantType']>>());
45+
test('ModelVariantType', () => assert<Equals<z.infer<typeof zModelVariantType>, S['ModelVariantType']>>());
4246

4347
// Misc types
4448
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());

invokeai/frontend/web/src/features/nodes/types/common.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
7373
// #endregion
7474

7575
// #region Model-related schemas
76-
const zBaseModel = z.enum([
76+
export const zBaseModelType = z.enum([
7777
'any',
7878
'sd-1',
7979
'sd-2',
@@ -90,7 +90,7 @@ const zBaseModel = z.enum([
9090
'veo3',
9191
'runway',
9292
]);
93-
export type BaseModelType = z.infer<typeof zBaseModel>;
93+
export type BaseModelType = z.infer<typeof zBaseModelType>;
9494
export const zMainModelBase = z.enum([
9595
'sd-1',
9696
'sd-2',
@@ -143,11 +143,15 @@ const zSubModelType = z.enum([
143143
'safety_checker',
144144
]);
145145
export type SubModelType = z.infer<typeof zSubModelType>;
146+
147+
export const zClipVariantType = z.enum(['large', 'gigantic']);
148+
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
149+
146150
export const zModelIdentifierField = z.object({
147151
key: z.string().min(1),
148152
hash: z.string().min(1),
149153
name: z.string().min(1),
150-
base: zBaseModel,
154+
base: zBaseModelType,
151155
type: zModelType,
152156
submodel_type: zSubModelType.nullish(),
153157
});

invokeai/frontend/web/src/features/nodes/types/field.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,17 @@ import { assert } from 'tsafe';
99
import { z } from 'zod';
1010

1111
import type { ImageField } from './common';
12-
import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common';
12+
import {
13+
zBaseModelType,
14+
zBoardField,
15+
zClipVariantType,
16+
zColorField,
17+
zImageField,
18+
zModelIdentifierField,
19+
zModelType,
20+
zModelVariantType,
21+
zSchedulerField,
22+
} from './common';
1323

1424
/**
1525
* zod schemas & inferred types for fields.
@@ -60,6 +70,9 @@ const zFieldInputTemplateBase = zFieldTemplateBase.extend({
6070
default: z.undefined(),
6171
ui_component: zFieldUIComponent.nullish(),
6272
ui_choice_labels: z.record(z.string(), z.string()).nullish(),
73+
ui_model_base: z.array(zBaseModelType).nullish(),
74+
ui_model_type: z.array(zModelType).nullish(),
75+
ui_model_variant: z.array(zModelVariantType.or(zClipVariantType)).nullish(),
6376
});
6477
const zFieldOutputTemplateBase = zFieldTemplateBase.extend({
6578
fieldKind: z.literal('output'),

invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,18 @@ export const buildFieldInputTemplate = (
892892
fieldName: string,
893893
fieldType: FieldType
894894
): FieldInputTemplate => {
895-
const { input, ui_hidden, ui_component, ui_type, ui_order, ui_choice_labels, orig_required: required } = fieldSchema;
895+
const {
896+
input,
897+
ui_hidden,
898+
ui_component,
899+
ui_type,
900+
ui_order,
901+
ui_choice_labels,
902+
orig_required: required,
903+
ui_model_base,
904+
ui_model_type,
905+
ui_model_variant,
906+
} = fieldSchema;
896907

897908
// This is the base field template that is common to all fields. The builder function will add all other
898909
// properties to this template.
@@ -908,6 +919,9 @@ export const buildFieldInputTemplate = (
908919
ui_type,
909920
ui_order,
910921
ui_choice_labels,
922+
ui_model_base,
923+
ui_model_type,
924+
ui_model_variant,
911925
};
912926

913927
if (isStatefulFieldType(fieldType)) {

0 commit comments

Comments
 (0)