Skip to content

Commit 035d943

Browse files
feat(ui): support filtering on model format
1 parent bdeb9fb commit 035d943

File tree

6 files changed

+31
-10
lines changed

6 files changed

+31
-10
lines changed

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,12 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
5555
) {
5656
return false;
5757
}
58+
if (fieldTemplate.ui_model_format && !fieldTemplate.ui_model_format.includes(config.format)) {
59+
return false;
60+
}
5861
return true;
5962
});
60-
}, [data, fieldTemplate.ui_model_base, fieldTemplate.ui_model_type, fieldTemplate.ui_model_variant]);
63+
}, [data, fieldTemplate]);
6164

6265
return (
6366
<ModelFieldCombobox

invokeai/frontend/web/src/features/nodes/hooks/useInputFieldNamesByStatus.ts

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,16 @@ import { createSelector } from '@reduxjs/toolkit';
22
import { useAppSelector } from 'app/store/storeHooks';
33
import { useInvocationNodeContext } from 'features/nodes/components/flow/nodes/Invocation/context';
44
import type { FieldInputTemplate } from 'features/nodes/types/field';
5-
import { isSingleOrCollection } from 'features/nodes/types/field';
6-
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
5+
import { isSingleOrCollection, isStatefulFieldType } from 'features/nodes/types/field';
76
import { useMemo } from 'react';
87

98
const isConnectionInputField = (field: FieldInputTemplate) => {
10-
return (
11-
(field.input === 'connection' && !isSingleOrCollection(field.type)) || !(field.type.name in TEMPLATE_BUILDER_MAP)
12-
);
9+
return (field.input === 'connection' && !isSingleOrCollection(field.type)) || !isStatefulFieldType(field.type);
1310
};
1411

1512
const isAnyOrDirectInputField = (field: FieldInputTemplate) => {
1613
return (
17-
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) &&
18-
field.type.name in TEMPLATE_BUILDER_MAP
14+
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) && isStatefulFieldType(field.type)
1915
);
2016
};
2117

invokeai/frontend/web/src/features/nodes/types/common.test-d.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import type {
1313
SubModelType,
1414
T2IAdapterField,
1515
zClipVariantType,
16+
zModelFormat,
1617
zModelVariantType,
1718
} from 'features/nodes/types/common';
1819
import type { Invocation, S } from 'services/api/types';
@@ -43,6 +44,7 @@ describe('Common types', () => {
4344
test('ModelIdentifier', () => assert<Equals<SubModelType, S['SubModelType']>>());
4445
test('ClipVariantType', () => assert<Equals<z.infer<typeof zClipVariantType>, S['ClipVariantType']>>());
4546
test('ModelVariantType', () => assert<Equals<z.infer<typeof zModelVariantType>, S['ModelVariantType']>>());
47+
test('ModelFormat', () => assert<Equals<z.infer<typeof zModelFormat>, S['ModelFormat']>>());
4648

4749
// Misc types
4850
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());

invokeai/frontend/web/src/features/nodes/types/common.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,22 @@ export type SubModelType = z.infer<typeof zSubModelType>;
146146

147147
export const zClipVariantType = z.enum(['large', 'gigantic']);
148148
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
149+
export const zModelFormat = z.enum([
150+
'omi',
151+
'diffusers',
152+
'checkpoint',
153+
'lycoris',
154+
'onnx',
155+
'olive',
156+
'embedding_file',
157+
'embedding_folder',
158+
'invokeai',
159+
't5_encoder',
160+
'bnb_quantized_int8b',
161+
'bnb_quantized_nf4b',
162+
'gguf_quantized',
163+
'api',
164+
]);
149165

150166
export const zModelIdentifierField = z.object({
151167
key: z.string().min(1),

invokeai/frontend/web/src/features/nodes/types/field.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import {
1515
zClipVariantType,
1616
zColorField,
1717
zImageField,
18+
zModelFormat,
1819
zModelIdentifierField,
1920
zModelType,
2021
zModelVariantType,
@@ -73,6 +74,7 @@ const zFieldInputTemplateBase = zFieldTemplateBase.extend({
7374
ui_model_base: z.array(zBaseModelType).nullish(),
7475
ui_model_type: z.array(zModelType).nullish(),
7576
ui_model_variant: z.array(zModelVariantType.or(zClipVariantType)).nullish(),
77+
ui_model_format: z.array(zModelFormat).nullish(),
7678
});
7779
const zFieldOutputTemplateBase = zFieldTemplateBase.extend({
7880
fieldKind: z.literal('output'),

invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ const buildImageGeneratorFieldInputTemplate: FieldInputTemplateBuilder<ImageGene
449449
return template;
450450
};
451451

452-
export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
452+
const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
453453
BoardField: buildBoardFieldInputTemplate,
454454
BooleanField: buildBooleanFieldInputTemplate,
455455
ColorField: buildColorFieldInputTemplate,
@@ -464,7 +464,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
464464
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
465465
StringGeneratorField: buildStringGeneratorFieldInputTemplate,
466466
ImageGeneratorField: buildImageGeneratorFieldInputTemplate,
467-
} as const;
467+
};
468468

469469
export const buildFieldInputTemplate = (
470470
fieldSchema: InvocationFieldSchema,
@@ -482,6 +482,7 @@ export const buildFieldInputTemplate = (
482482
ui_model_base,
483483
ui_model_type,
484484
ui_model_variant,
485+
ui_model_format,
485486
} = fieldSchema;
486487

487488
// This is the base field template that is common to all fields. The builder function will add all other
@@ -501,6 +502,7 @@ export const buildFieldInputTemplate = (
501502
ui_model_base,
502503
ui_model_type,
503504
ui_model_variant,
505+
ui_model_format,
504506
};
505507

506508
if (isStatefulFieldType(fieldType)) {

0 commit comments

Comments
 (0)