|
| 1 | +import type { SystemStyleObject } from '@invoke-ai/ui-library'; |
| 2 | +import { Button, Checkbox, Flex, Menu, MenuButton, MenuItem, MenuList, Text } from '@invoke-ai/ui-library'; |
| 3 | +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; |
| 4 | +import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; |
| 5 | +import { |
| 6 | + modelSelectionChanged, |
| 7 | + selectFilteredModelType, |
| 8 | + selectSearchTerm, |
| 9 | + selectSelectedModelKeys, |
| 10 | +} from 'features/modelManagerV2/store/modelManagerV2Slice'; |
| 11 | +import { t } from 'i18next'; |
| 12 | +import { memo, useCallback, useMemo } from 'react'; |
| 13 | +import { PiCaretDownBold, PiTrashSimpleBold } from 'react-icons/pi'; |
| 14 | +import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; |
| 15 | +import type { AnyModelConfig } from 'services/api/types'; |
| 16 | + |
| 17 | +import { useBulkDeleteModal } from './ModelList'; |
| 18 | + |
| 19 | +const ModelListBulkActionsSx: SystemStyleObject = { |
| 20 | + alignItems: 'center', |
| 21 | + justifyContent: 'space-between', |
| 22 | + width: '100%', |
| 23 | +}; |
| 24 | + |
| 25 | +type ModelListBulkActionsProps = { |
| 26 | + sx?: SystemStyleObject; |
| 27 | +}; |
| 28 | + |
| 29 | +export const ModelListBulkActions = memo(({ sx }: ModelListBulkActionsProps) => { |
| 30 | + const dispatch = useAppDispatch(); |
| 31 | + const filteredModelType = useAppSelector(selectFilteredModelType); |
| 32 | + const selectedModelKeys = useAppSelector(selectSelectedModelKeys); |
| 33 | + const searchTerm = useAppSelector(selectSearchTerm); |
| 34 | + const { data } = useGetModelConfigsQuery(); |
| 35 | + const bulkDeleteModal = useBulkDeleteModal(); |
| 36 | + |
| 37 | + const handleBulkDelete = useCallback(() => { |
| 38 | + bulkDeleteModal.open(); |
| 39 | + }, [bulkDeleteModal]); |
| 40 | + |
| 41 | + // Calculate displayed (filtered) model keys |
| 42 | + const displayedModelKeys = useMemo(() => { |
| 43 | + const modelConfigs = modelConfigsAdapterSelectors.selectAll(data ?? { ids: [], entities: {} }); |
| 44 | + const filteredModels = modelsFilter(modelConfigs, searchTerm, filteredModelType); |
| 45 | + return filteredModels.map((m) => m.key); |
| 46 | + }, [data, searchTerm, filteredModelType]); |
| 47 | + |
| 48 | + const { allSelected, someSelected } = useMemo(() => { |
| 49 | + if (displayedModelKeys.length === 0) { |
| 50 | + return { allSelected: false, someSelected: false }; |
| 51 | + } |
| 52 | + const selectedSet = new Set(selectedModelKeys); |
| 53 | + const displayedSelectedCount = displayedModelKeys.filter((key) => selectedSet.has(key)).length; |
| 54 | + return { |
| 55 | + allSelected: displayedSelectedCount === displayedModelKeys.length, |
| 56 | + someSelected: displayedSelectedCount > 0 && displayedSelectedCount < displayedModelKeys.length, |
| 57 | + }; |
| 58 | + }, [displayedModelKeys, selectedModelKeys]); |
| 59 | + |
| 60 | + const handleToggleAll = useCallback(() => { |
| 61 | + if (allSelected) { |
| 62 | + // Deselect all displayed models |
| 63 | + const displayedSet = new Set(displayedModelKeys); |
| 64 | + const newSelection = selectedModelKeys.filter((key) => !displayedSet.has(key)); |
| 65 | + dispatch(modelSelectionChanged(newSelection)); |
| 66 | + } else { |
| 67 | + // Select all displayed models (merge with existing selection) |
| 68 | + const selectedSet = new Set(selectedModelKeys); |
| 69 | + displayedModelKeys.forEach((key) => selectedSet.add(key)); |
| 70 | + dispatch(modelSelectionChanged(Array.from(selectedSet))); |
| 71 | + } |
| 72 | + }, [allSelected, displayedModelKeys, selectedModelKeys, dispatch]); |
| 73 | + |
| 74 | + const selectionCount = selectedModelKeys.length; |
| 75 | + |
| 76 | + return ( |
| 77 | + <Flex sx={{ ...ModelListBulkActionsSx, sx }}> |
| 78 | + <Checkbox |
| 79 | + isChecked={allSelected} |
| 80 | + isIndeterminate={someSelected} |
| 81 | + onChange={handleToggleAll} |
| 82 | + isDisabled={displayedModelKeys.length === 0} |
| 83 | + aria-label={t('modelManager.selectAll')} |
| 84 | + > |
| 85 | + <Text variant="subtext1" color="base.400"> |
| 86 | + {t('modelManager.selectAll')} |
| 87 | + </Text> |
| 88 | + </Checkbox> |
| 89 | + |
| 90 | + <Flex alignItems="center" gap={4}> |
| 91 | + <Text variant="subtext" color="base.400"> |
| 92 | + {selectionCount} {t('common.selected')} |
| 93 | + </Text> |
| 94 | + <Menu placement="bottom-end"> |
| 95 | + <MenuButton |
| 96 | + as={Button} |
| 97 | + disabled={selectionCount === 0} |
| 98 | + size="sm" |
| 99 | + rightIcon={<PiCaretDownBold />} |
| 100 | + flexShrink={0} |
| 101 | + variant="outline" |
| 102 | + > |
| 103 | + {t('modelManager.actions')} |
| 104 | + </MenuButton> |
| 105 | + <MenuList> |
| 106 | + <MenuItem icon={<PiTrashSimpleBold />} onClick={handleBulkDelete} color="error.300"> |
| 107 | + {t('modelManager.deleteModels', { count: selectionCount })} |
| 108 | + </MenuItem> |
| 109 | + </MenuList> |
| 110 | + </Menu> |
| 111 | + </Flex> |
| 112 | + </Flex> |
| 113 | + ); |
| 114 | +}); |
| 115 | + |
| 116 | +ModelListBulkActions.displayName = 'ModelListBulkActions'; |
| 117 | + |
| 118 | +const modelsFilter = <T extends AnyModelConfig>( |
| 119 | + data: T[], |
| 120 | + nameFilter: string, |
| 121 | + filteredModelType: FilterableModelType | null |
| 122 | +): T[] => { |
| 123 | + return data.filter((model) => { |
| 124 | + const matchesFilter = |
| 125 | + model.name.toLowerCase().includes(nameFilter.toLowerCase()) || |
| 126 | + model.base.toLowerCase().includes(nameFilter.toLowerCase()) || |
| 127 | + model.type.toLowerCase().includes(nameFilter.toLowerCase()) || |
| 128 | + model.description?.toLowerCase().includes(nameFilter.toLowerCase()) || |
| 129 | + model.format.toLowerCase().includes(nameFilter.toLowerCase()); |
| 130 | + |
| 131 | + const matchesType = getMatchesType(model, filteredModelType); |
| 132 | + |
| 133 | + return matchesFilter && matchesType; |
| 134 | + }); |
| 135 | +}; |
| 136 | + |
| 137 | +const getMatchesType = (modelConfig: AnyModelConfig, filteredModelType: FilterableModelType | null): boolean => { |
| 138 | + if (filteredModelType === 'refiner') { |
| 139 | + return modelConfig.base === 'sdxl-refiner'; |
| 140 | + } |
| 141 | + |
| 142 | + if (filteredModelType === 'main' && modelConfig.base === 'sdxl-refiner') { |
| 143 | + return false; |
| 144 | + } |
| 145 | + |
| 146 | + return filteredModelType ? modelConfig.type === filteredModelType : true; |
| 147 | +}; |
0 commit comments