Skip to content

Commit eb9a417

Browse files
feat(ui): allow removing individual images from batch
1 parent 3c43351 commit eb9a417

File tree

4 files changed

+128
-69
lines changed

4 files changed

+128
-69
lines changed

invokeai/frontend/web/src/features/dnd/dnd.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ const _setNodeImageFieldImage = buildTypeAndKey('set-node-image-field-image');
221221
export type SetNodeImageFieldImageDndTargetData = DndData<
222222
typeof _setNodeImageFieldImage.type,
223223
typeof _setNodeImageFieldImage.key,
224-
{ fieldIdentifer: FieldIdentifier }
224+
{ fieldIdentifier: FieldIdentifier }
225225
>;
226226
export const setNodeImageFieldImageDndTarget: DndTarget<SetNodeImageFieldImageDndTargetData, SingleImageDndSourceData> =
227227
{
@@ -236,8 +236,8 @@ export const setNodeImageFieldImageDndTarget: DndTarget<SetNodeImageFieldImageDn
236236
},
237237
handler: ({ sourceData, targetData, dispatch }) => {
238238
const { imageDTO } = sourceData.payload;
239-
const { fieldIdentifer } = targetData.payload;
240-
setNodeImageFieldImage({ fieldIdentifer, imageDTO, dispatch });
239+
const { fieldIdentifier } = targetData.payload;
240+
setNodeImageFieldImage({ fieldIdentifier, imageDTO, dispatch });
241241
},
242242
};
243243
//#endregion
@@ -247,7 +247,7 @@ const _addImagesToNodeImageFieldCollection = buildTypeAndKey('add-images-to-imag
247247
export type AddImagesToNodeImageFieldCollection = DndData<
248248
typeof _addImagesToNodeImageFieldCollection.type,
249249
typeof _addImagesToNodeImageFieldCollection.key,
250-
{ fieldIdentifer: FieldIdentifier }
250+
{ fieldIdentifier: FieldIdentifier }
251251
>;
252252
export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
253253
AddImagesToNodeImageFieldCollection,
@@ -267,7 +267,7 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
267267
return;
268268
}
269269

270-
const { fieldIdentifer } = targetData.payload;
270+
const { fieldIdentifier } = targetData.payload;
271271
const imageDTOs: ImageDTO[] = [];
272272

273273
if (singleImageDndSource.typeGuard(sourceData)) {
@@ -276,7 +276,7 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
276276
imageDTOs.push(...sourceData.payload.imageDTOs);
277277
}
278278

279-
addImagesToNodeImageFieldCollectionAction({ fieldIdentifer, imageDTOs, dispatch, getState });
279+
addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState });
280280
},
281281
};
282282
//#endregion

invokeai/frontend/web/src/features/imageActions/actions.ts

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,35 +69,59 @@ export const setUpscaleInitialImage = (arg: { imageDTO: ImageDTO; dispatch: AppD
6969

7070
export const setNodeImageFieldImage = (arg: {
7171
imageDTO: ImageDTO;
72-
fieldIdentifer: FieldIdentifier;
72+
fieldIdentifier: FieldIdentifier;
7373
dispatch: AppDispatch;
7474
}) => {
75-
const { imageDTO, fieldIdentifer, dispatch } = arg;
76-
dispatch(fieldImageValueChanged({ ...fieldIdentifer, value: imageDTO }));
75+
const { imageDTO, fieldIdentifier, dispatch } = arg;
76+
dispatch(fieldImageValueChanged({ ...fieldIdentifier, value: imageDTO }));
7777
};
7878

7979
export const addImagesToNodeImageFieldCollectionAction = (arg: {
8080
imageDTOs: ImageDTO[];
81-
fieldIdentifer: FieldIdentifier;
81+
fieldIdentifier: FieldIdentifier;
8282
dispatch: AppDispatch;
8383
getState: () => RootState;
8484
}) => {
85-
const { imageDTOs, fieldIdentifer, dispatch, getState } = arg;
85+
const { imageDTOs, fieldIdentifier, dispatch, getState } = arg;
8686
const fieldInputInstance = selectFieldInputInstance(
8787
selectNodesSlice(getState()),
88-
fieldIdentifer.nodeId,
89-
fieldIdentifer.fieldName
88+
fieldIdentifier.nodeId,
89+
fieldIdentifier.fieldName
9090
);
9191

9292
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
93-
log.warn({ fieldIdentifer }, 'Attempted to add images to a non-image field collection');
93+
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
9494
return;
9595
}
9696

9797
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
9898
images.push(...imageDTOs.map(({ image_name }) => ({ image_name })));
9999
const uniqueImages = uniqBy(images, 'image_name');
100-
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifer, value: uniqueImages }));
100+
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
101+
};
102+
103+
export const removeImageFromNodeImageFieldCollectionAction = (arg: {
104+
imageName: string;
105+
fieldIdentifier: FieldIdentifier;
106+
dispatch: AppDispatch;
107+
getState: () => RootState;
108+
}) => {
109+
const { imageName, fieldIdentifier, dispatch, getState } = arg;
110+
const fieldInputInstance = selectFieldInputInstance(
111+
selectNodesSlice(getState()),
112+
fieldIdentifier.nodeId,
113+
fieldIdentifier.fieldName
114+
);
115+
116+
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
117+
log.warn({ fieldIdentifier }, 'Attempted to remove image from a non-image field collection');
118+
return;
119+
}
120+
121+
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
122+
const imagesWithoutTheImageToRemove = images.filter((image) => image.image_name !== imageName);
123+
const uniqueImages = uniqBy(imagesWithoutTheImageToRemove, 'image_name');
124+
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
101125
};
102126

103127
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
Lines changed: 78 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,70 @@
11
import type { SystemStyleObject } from '@invoke-ai/ui-library';
2-
import { Flex, Grid, GridItem, IconButton } from '@invoke-ai/ui-library';
3-
import { useAppDispatch } from 'app/store/storeHooks';
2+
import { Flex, Grid, GridItem } from '@invoke-ai/ui-library';
3+
import { useAppStore } from 'app/store/nanostores/store';
4+
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
45
import { UploadMultipleImageButton } from 'common/hooks/useImageUploadButton';
56
import type { AddImagesToNodeImageFieldCollection } from 'features/dnd/dnd';
67
import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
78
import { DndDropTarget } from 'features/dnd/DndDropTarget';
8-
import { DndImageFromImageName } from 'features/dnd/DndImageFromImageName';
9+
import { DndImage } from 'features/dnd/DndImage';
10+
import { DndImageIcon } from 'features/dnd/DndImageIcon';
11+
import { removeImageFromNodeImageFieldCollectionAction } from 'features/imageActions/actions';
912
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
1013
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
1114
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
1215
import { memo, useCallback, useMemo } from 'react';
1316
import { useTranslation } from 'react-i18next';
14-
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
17+
import { PiArrowCounterClockwiseBold, PiExclamationMarkBold } from 'react-icons/pi';
18+
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
1519
import type { ImageDTO } from 'services/api/types';
1620

1721
import type { FieldComponentProps } from './types';
1822

1923
const sx = {
24+
borderWidth: 1,
2025
'&[data-error=true]': {
2126
borderColor: 'error.500',
2227
borderStyle: 'solid',
23-
borderWidth: 1,
2428
},
2529
} satisfies SystemStyleObject;
2630

2731
export const ImageFieldCollectionInputComponent = memo(
2832
(props: FieldComponentProps<ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate>) => {
2933
const { t } = useTranslation();
3034
const { nodeId, field } = props;
31-
const dispatch = useAppDispatch();
32-
const isInvalid = useFieldIsInvalid(nodeId, field.name);
35+
const store = useAppStore();
3336

34-
const onReset = useCallback(() => {
35-
dispatch(
36-
fieldImageCollectionValueChanged({
37-
nodeId,
38-
fieldName: field.name,
39-
value: [],
40-
})
41-
);
42-
}, [dispatch, field.name, nodeId]);
37+
const isInvalid = useFieldIsInvalid(nodeId, field.name);
4338

4439
const dndTargetData = useMemo<AddImagesToNodeImageFieldCollection>(
45-
() => addImagesToNodeImageFieldCollectionDndTarget.getData({ fieldIdentifer: { nodeId, fieldName: field.name } }),
40+
() =>
41+
addImagesToNodeImageFieldCollectionDndTarget.getData({ fieldIdentifier: { nodeId, fieldName: field.name } }),
4642
[field, nodeId]
4743
);
4844

4945
const onUpload = useCallback(
5046
(imageDTOs: ImageDTO[]) => {
51-
dispatch(
47+
store.dispatch(
5248
fieldImageCollectionValueChanged({
5349
nodeId,
5450
fieldName: field.name,
5551
value: imageDTOs,
5652
})
5753
);
5854
},
59-
[dispatch, field.name, nodeId]
55+
[store, nodeId, field.name]
56+
);
57+
58+
const onRemoveImage = useCallback(
59+
(imageName: string) => {
60+
removeImageFromNodeImageFieldCollectionAction({
61+
imageName,
62+
fieldIdentifier: { nodeId, fieldName: field.name },
63+
dispatch: store.dispatch,
64+
getState: store.getState,
65+
});
66+
},
67+
[field.name, nodeId, store.dispatch, store.getState]
6068
);
6169

6270
return (
@@ -80,33 +88,23 @@ export const ImageFieldCollectionInputComponent = memo(
8088
/>
8189
)}
8290
{field.value && field.value.length > 0 && (
83-
<>
84-
<Grid
85-
className="nopan"
86-
borderRadius="base"
87-
w="full"
88-
h="full"
89-
templateColumns={`repeat(${Math.min(field.value.length, 3)}, 1fr)`}
90-
gap={1}
91-
sx={sx}
92-
data-error={isInvalid}
93-
p={1}
94-
>
95-
{field.value.map(({ image_name }) => (
96-
<GridItem key={image_name}>
97-
<DndImageFromImageName imageName={image_name} asThumbnail />
98-
</GridItem>
99-
))}
100-
</Grid>
101-
<IconButton
102-
aria-label="reset"
103-
icon={<PiArrowCounterClockwiseBold />}
104-
position="absolute"
105-
top={0}
106-
insetInlineEnd={0}
107-
onClick={onReset}
108-
/>
109-
</>
91+
<Grid
92+
className="nopan"
93+
borderRadius="base"
94+
w="full"
95+
h="full"
96+
templateColumns="repeat(3, 1fr)"
97+
gap={1}
98+
sx={sx}
99+
data-error={isInvalid}
100+
p={1}
101+
>
102+
{field.value.map(({ image_name }) => (
103+
<GridItem key={image_name} position="relative">
104+
<ImageGridItemContent imageName={image_name} onRemoveImage={onRemoveImage} />
105+
</GridItem>
106+
))}
107+
</Grid>
110108
)}
111109
<DndDropTarget
112110
dndTarget={addImagesToNodeImageFieldCollectionDndTarget}
@@ -119,3 +117,37 @@ export const ImageFieldCollectionInputComponent = memo(
119117
);
120118

121119
ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';
120+
121+
const ImageGridItemContent = memo(
122+
({ imageName, onRemoveImage }: { imageName: string; onRemoveImage: (imageName: string) => void }) => {
123+
const query = useGetImageDTOQuery(imageName);
124+
const onClickRemove = useCallback(() => {
125+
onRemoveImage(imageName);
126+
}, [imageName, onRemoveImage]);
127+
128+
if (query.isLoading) {
129+
return <IAINoContentFallbackWithSpinner />;
130+
}
131+
132+
if (!query.data) {
133+
return <IAINoContentFallback icon={<PiExclamationMarkBold />} />;
134+
}
135+
136+
return (
137+
<>
138+
<DndImage imageDTO={query.data} asThumbnail />
139+
<DndImageIcon
140+
onClick={onClickRemove}
141+
icon={<PiArrowCounterClockwiseBold />}
142+
tooltip="Reset Image"
143+
position="absolute"
144+
flexDir="column"
145+
top={1}
146+
insetInlineEnd={1}
147+
gap={1}
148+
/>
149+
</>
150+
);
151+
}
152+
);
153+
ImageGridItemContent.displayName = 'ImageGridItemContent';

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ const ImageFieldInputComponent = (props: FieldComponentProps<ImageFieldInputInst
3838
const dndTargetData = useMemo<SetNodeImageFieldImageDndTargetData>(
3939
() =>
4040
setNodeImageFieldImageDndTarget.getData(
41-
{ fieldIdentifer: { nodeId, fieldName: field.name } },
41+
{ fieldIdentifier: { nodeId, fieldName: field.name } },
4242
field.value?.image_name
4343
),
4444
[field, nodeId]
@@ -85,13 +85,16 @@ const ImageFieldInputComponent = (props: FieldComponentProps<ImageFieldInputInst
8585
{imageDTO && (
8686
<>
8787
<DndImage imageDTO={imageDTO} minW={8} minH={8} />
88-
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
89-
<DndImageIcon
90-
onClick={handleReset}
91-
icon={imageDTO ? <PiArrowCounterClockwiseBold /> : undefined}
92-
tooltip="Reset Image"
93-
/>
94-
</Flex>
88+
<DndImageIcon
89+
onClick={handleReset}
90+
icon={imageDTO ? <PiArrowCounterClockwiseBold /> : undefined}
91+
tooltip="Reset Image"
92+
position="absolute"
93+
flexDir="column"
94+
top={1}
95+
insetInlineEnd={1}
96+
gap={1}
97+
/>
9598
</>
9699
)}
97100
<DndDropTarget

0 commit comments

Comments
 (0)