Skip to content

Commit bd893cf

Browse files
refactor(ui)refactor(ui): more cleanup of model categories
1 parent b68871a commit bd893cf

File tree

3 files changed

+79
-67
lines changed

3 files changed

+79
-67
lines changed

invokeai/frontend/web/src/features/modelManagerV2/models.ts

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,84 +18,117 @@ import {
1818
isTIModelConfig,
1919
isUnknownModelConfig,
2020
isVAEModelConfig,
21+
isVideoModelConfig,
2122
} from 'services/api/types';
23+
import { objectEntries } from 'tsafe';
2224

23-
type ModelCategoryData = {
25+
import type { FilterableModelType } from './store/modelManagerV2Slice';
26+
27+
export type ModelCategoryData = {
28+
category: FilterableModelType;
2429
i18nKey: string;
2530
filter: (config: AnyModelConfig) => boolean;
2631
};
2732

28-
export const MODEL_CATEGORIES: Record<string, ModelCategoryData> = {
33+
export const MODEL_CATEGORIES: Record<FilterableModelType, ModelCategoryData> = {
34+
unknown: {
35+
category: 'unknown',
36+
i18nKey: 'common.unknown',
37+
filter: isUnknownModelConfig,
38+
},
2939
main: {
30-
i18nKey: 'model_manager.category.main_models',
40+
category: 'main',
41+
i18nKey: 'modelManager.main',
3142
filter: isNonRefinerMainModelConfig,
3243
},
3344
refiner: {
34-
i18nKey: 'model_manager.category.refiner_models',
45+
category: 'refiner',
46+
i18nKey: 'sdxl.refiner',
3547
filter: isRefinerMainModelModelConfig,
3648
},
3749
lora: {
38-
i18nKey: 'model_manager.category.lora_models',
50+
category: 'lora',
51+
i18nKey: 'modelManager.loraModels',
3952
filter: isLoRAModelConfig,
4053
},
4154
embedding: {
42-
i18nKey: 'model_manager.category.embedding_models',
55+
category: 'embedding',
56+
i18nKey: 'modelManager.textualInversions',
4357
filter: isTIModelConfig,
4458
},
4559
controlnet: {
46-
i18nKey: 'model_manager.category.controlnet_models',
60+
category: 'controlnet',
61+
i18nKey: 'ControlNet',
4762
filter: isControlNetModelConfig,
4863
},
4964
t2i_adapter: {
50-
i18nKey: 'model_manager.category.t2i_adapter_models',
65+
category: 't2i_adapter',
66+
i18nKey: 'common.t2iAdapter',
5167
filter: isT2IAdapterModelConfig,
5268
},
5369
t5_encoder: {
54-
i18nKey: 'model_manager.category.t5_encoder_models',
70+
category: 't5_encoder',
71+
i18nKey: 'modelManager.t5Encoder',
5572
filter: isT5EncoderModelConfig,
5673
},
5774
control_lora: {
58-
i18nKey: 'model_manager.category.control_lora_models',
75+
category: 'control_lora',
76+
i18nKey: 'modelManager.controlLora',
5977
filter: isControlLoRAModelConfig,
6078
},
6179
clip_embed: {
62-
i18nKey: 'model_manager.category.clip_embed_models',
80+
category: 'clip_embed',
81+
i18nKey: 'modelManager.clipEmbed',
6382
filter: isCLIPEmbedModelConfig,
6483
},
65-
spandrel: {
66-
i18nKey: 'model_manager.category.spandrel_image_to_image_models',
84+
spandrel_image_to_image: {
85+
category: 'spandrel_image_to_image',
86+
i18nKey: 'modelManager.spandrelImageToImage',
6787
filter: isSpandrelImageToImageModelConfig,
6888
},
6989
ip_adapter: {
70-
i18nKey: 'model_manager.category.ip_adapter_models',
90+
category: 'ip_adapter',
91+
i18nKey: 'common.ipAdapter',
7192
filter: isIPAdapterModelConfig,
7293
},
7394
vae: {
74-
i18nKey: 'model_manager.category.vae_models',
95+
category: 'vae',
96+
i18nKey: 'VAE',
7597
filter: isVAEModelConfig,
7698
},
7799
clip_vision: {
78-
i18nKey: 'model_manager.category.clip_vision_models',
100+
category: 'clip_vision',
101+
i18nKey: 'CLIP Vision',
79102
filter: isCLIPVisionModelConfig,
80103
},
81104
siglip: {
82-
i18nKey: 'model_manager.category.siglip_models',
105+
category: 'siglip',
106+
i18nKey: 'modelManager.sigLip',
83107
filter: isSigLipModelConfig,
84108
},
85109
flux_redux: {
86-
i18nKey: 'model_manager.category.flux_redux_models',
110+
category: 'flux_redux',
111+
i18nKey: 'modelManager.fluxRedux',
87112
filter: isFluxReduxModelConfig,
88113
},
89-
llava_one_vision: {
90-
i18nKey: 'model_manager.category.llava_one_vision_models',
114+
llava_onevision: {
115+
category: 'llava_onevision',
116+
i18nKey: 'modelManager.llavaOnevision',
91117
filter: isLLaVAModelConfig,
92118
},
93-
unknown: {
94-
i18nKey: 'model_manager.category.unknown_models',
95-
filter: isUnknownModelConfig,
119+
video: {
120+
category: 'video',
121+
i18nKey: 'Video',
122+
filter: isVideoModelConfig,
96123
},
97124
};
98125

126+
export const MODEL_CATEGORIES_AS_LIST = objectEntries(MODEL_CATEGORIES).map(([category, { i18nKey, filter }]) => ({
127+
category,
128+
i18nKey,
129+
filter,
130+
}));
131+
99132
/**
100133
* Mapping of model base to its color
101134
*/

invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { Flex, Text } from '@invoke-ai/ui-library';
22
import { logger } from 'app/logging/logger';
33
import { useAppSelector } from 'app/store/storeHooks';
44
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
5-
import { MODEL_CATEGORIES } from 'features/modelManagerV2/models';
5+
import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
66
import {
77
type FilterableModelType,
88
selectFilteredModelType,
@@ -31,7 +31,7 @@ const ModelList = () => {
3131
const byCategory: { i18nKey: string; configs: AnyModelConfig[] }[] = [];
3232
const total = baseFilteredModelConfigs.length;
3333
let renderedTotal = 0;
34-
for (const { i18nKey, filter } of Object.values(MODEL_CATEGORIES)) {
34+
for (const { i18nKey, filter } of MODEL_CATEGORIES_AS_LIST) {
3535
const configs = baseFilteredModelConfigs.filter(filter);
3636
renderedTotal += configs.length;
3737
byCategory.push({ i18nKey, configs });
Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,49 @@
11
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
22
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
3-
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
3+
import type { ModelCategoryData } from 'features/modelManagerV2/models';
4+
import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
45
import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
5-
import { memo, useCallback, useMemo } from 'react';
6+
import { memo, useCallback } from 'react';
67
import { useTranslation } from 'react-i18next';
78
import { PiFunnelBold } from 'react-icons/pi';
8-
import { objectKeys } from 'tsafe';
99

1010
export const ModelTypeFilter = memo(() => {
1111
const { t } = useTranslation();
1212
const dispatch = useAppDispatch();
13-
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
14-
() => ({
15-
main: t('modelManager.main'),
16-
refiner: t('sdxl.refiner'),
17-
lora: 'LoRA',
18-
embedding: t('modelManager.textualInversions'),
19-
controlnet: 'ControlNet',
20-
vae: 'VAE',
21-
t2i_adapter: t('common.t2iAdapter'),
22-
t5_encoder: t('modelManager.t5Encoder'),
23-
clip_embed: t('modelManager.clipEmbed'),
24-
ip_adapter: t('common.ipAdapter'),
25-
clip_vision: 'CLIP Vision',
26-
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
27-
control_lora: t('modelManager.controlLora'),
28-
siglip: t('modelManager.sigLip'),
29-
flux_redux: t('modelManager.fluxRedux'),
30-
llava_onevision: t('modelManager.llavaOnevision'),
31-
video: t('modelManager.video'),
32-
unknown: t('modelManager.unknown'),
33-
}),
34-
[t]
35-
);
3613
const filteredModelType = useAppSelector(selectFilteredModelType);
3714

38-
const selectModelType = useCallback(
39-
(option: FilterableModelType) => {
40-
dispatch(setFilteredModelType(option));
41-
},
42-
[dispatch]
43-
);
44-
4515
const clearModelType = useCallback(() => {
4616
dispatch(setFilteredModelType(null));
4717
}, [dispatch]);
4818

4919
return (
5020
<Menu>
5121
<MenuButton as={Button} size="sm" leftIcon={<PiFunnelBold />}>
52-
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
22+
{filteredModelType ? t(MODEL_CATEGORIES[filteredModelType].i18nKey) : t('modelManager.allModels')}
5323
</MenuButton>
5424
<MenuList>
5525
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
56-
{objectKeys(MODEL_TYPE_LABELS).map((option) => (
57-
<MenuItem
58-
key={option}
59-
bg={filteredModelType === option ? 'base.700' : 'transparent'}
60-
onClick={selectModelType.bind(null, option)}
61-
>
62-
{MODEL_TYPE_LABELS[option]}
63-
</MenuItem>
26+
{MODEL_CATEGORIES_AS_LIST.map((data) => (
27+
<ModelMenuItem key={data.category} data={data} />
6428
))}
6529
</MenuList>
6630
</Menu>
6731
);
6832
});
6933

7034
ModelTypeFilter.displayName = 'ModelTypeFilter';
35+
36+
const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => {
37+
const { t } = useTranslation();
38+
const dispatch = useAppDispatch();
39+
const filteredModelType = useAppSelector(selectFilteredModelType);
40+
const onClick = useCallback(() => {
41+
dispatch(setFilteredModelType(data.category));
42+
}, [data.category, dispatch]);
43+
return (
44+
<MenuItem bg={filteredModelType === data.category ? 'base.700' : 'transparent'} onClick={onClick}>
45+
{t(data.i18nKey)}
46+
</MenuItem>
47+
);
48+
});
49+
ModelMenuItem.displayName = 'ModelMenuItem';

0 commit comments

Comments
 (0)