Skip to content

Commit 03117ac

Browse files
committed
Extract explanations to its own query
1 parent b666041 commit 03117ac

File tree

15 files changed

+111
-77
lines changed

15 files changed

+111
-77
lines changed

web_ui/packages/core/src/requests/query-keys.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,13 @@ const getSelectedMediaItemQueryKeys = () => {
118118
taskId?: string,
119119
roiId?: string
120120
) => [...commonKey(mediaIdentifier), taskId, roiId, `${prefix}-predictions`, predictionCache],
121+
122+
EXPLANATIONS: (
123+
datasetIdentifier: DatasetIdentifier,
124+
mediaIdentifier: MediaIdentifier | undefined,
125+
taskId?: string
126+
) => ['explanations', datasetIdentifier, mediaIdentifier, taskId],
127+
121128
SELECTED: (mediaIdentifier: MediaIdentifier | undefined, taskId?: string) => [
122129
...commonKey(mediaIdentifier),
123130
taskId,

web_ui/src/core/annotations/services/prediction-service.interface.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ import { Explanation } from '../prediction.interface';
66

77
export interface PredictionResult {
88
annotations: ReadonlyArray<Annotation>;
9-
maps: Explanation[];
9+
}
10+
11+
export interface ExplanationResult {
12+
maps: ReadonlyArray<Explanation>;
1013
}
1114

1215
export interface InferenceServerStatusResult {

web_ui/src/pages/annotator/providers/annotator-provider/annotator-provider.component.tsx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ export const AnnotatorProvider = ({ children }: AnnotatorProviderProps): JSX.Ele
6969

7070
const [activeTool, setActiveTool] = useState<ToolType>(() => defaultToolForProject(activeDomains));
7171

72-
const { selectedMediaItem, predictionsQuery } = useSelectedMediaItem();
72+
const { selectedMediaItem, predictionsQuery, explanationsQuery } = useSelectedMediaItem();
7373
const isTaskChainSelectedClassification = isTaskChainDomainSelected(DOMAIN.CLASSIFICATION);
7474

7575
const initialPredictionAnnotations = predictionsQuery.data?.annotations;
@@ -79,6 +79,7 @@ export const AnnotatorProvider = ({ children }: AnnotatorProviderProps): JSX.Ele
7979
userProjectSettings,
8080
isTaskChainSelectedClassification
8181
);
82+
const explanations = explanationsQuery.data || [];
8283

8384
const { undoRedoActions, ...userAnnotationScene } = useAnnotationSceneState(
8485
initialAnnotations,
@@ -109,7 +110,7 @@ export const AnnotatorProvider = ({ children }: AnnotatorProviderProps): JSX.Ele
109110
settings={userProjectSettings}
110111
userAnnotationScene={userAnnotationScene}
111112
initPredictions={initialPredictionAnnotations}
112-
explanations={selectedMediaItem?.predictions?.maps || EMPTY_EXPLANATION}
113+
explanations={explanations || EMPTY_EXPLANATION}
113114
>
114115
<SubmitAnnotationsProvider
115116
settings={userProjectSettings}

web_ui/src/pages/annotator/providers/annotator-provider/use-initial-annotations.hook.test.tsx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ const geSelectedMediaItem = (predictions: Annotation[] = [], annotations: Annota
3333
...getMockedImageMediaItem({}),
3434
image: getMockedImage(),
3535
annotations,
36-
predictions: { annotations: predictions, maps: [] },
36+
predictions: { annotations: predictions },
37+
explanations: [],
3738
});
3839
const getInitialPredictionConfig = (isEnabled = true) => ({
3940
[FEATURES_KEYS.INITIAL_PREDICTION]: {

web_ui/src/pages/annotator/providers/prediction-provider/prediction-provider.component.tsx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ export const PredictionProvider = ({
326326
selectedInput,
327327
taskId: String(selectedTask?.id),
328328
enabled: isPredictionsQueryEnabled,
329-
onSuccess: runWhenNotDrawing(({ annotations: newRawPredictions, maps: newMaps }: PredictionResult) => {
329+
onSuccess: runWhenNotDrawing(({ annotations: newRawPredictions }: PredictionResult) => {
330330
const selectedPredictions = selectAnnotations(newRawPredictions, userSceneSelectedInputs);
331331

332332
if (isEmpty(userAnnotations)) {
@@ -348,7 +348,6 @@ export const PredictionProvider = ({
348348
canUpdatePrediction && userAnnotationScene.updateAnnotation(newPrediction);
349349
}
350350

351-
setExplanations(newMaps);
352351
setRawPredictions(selectedPredictions);
353352

354353
// Optionally update video timeline predictions

web_ui/src/pages/annotator/providers/prediction-provider/use-prediction-roi-query.hook.test.tsx

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ describe('usePredictionsRoiQuery', () => {
135135
});
136136

137137
await waitFor(() => {
138-
expect(result.current.data).toEqual({ annotations: [], maps: [] });
138+
expect(result.current.data).toEqual({ annotations: [] });
139139
});
140140
});
141141

@@ -149,7 +149,6 @@ describe('usePredictionsRoiQuery', () => {
149149
await waitFor(() => {
150150
expect(result.current.data).toEqual({
151151
annotations: [],
152-
maps: [],
153152
});
154153
});
155154

@@ -236,7 +235,7 @@ describe('usePredictionsRoiQuery', () => {
236235
});
237236

238237
await waitFor(() => {
239-
expect(result.current.data).toEqual({ annotations: [], maps: [] });
238+
expect(result.current.data).toEqual({ annotations: [] });
240239
});
241240
});
242241
});

web_ui/src/pages/annotator/providers/prediction-provider/use-prediction-roi-query.hook.ts

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@ import { useQuery } from '@tanstack/react-query';
99
import { AxiosError } from 'axios';
1010

1111
import { TaskChainInput } from '../../../../core/annotations/annotation.interface';
12-
import {
13-
PredictionCache,
14-
PredictionMode,
15-
PredictionResult,
16-
} from '../../../../core/annotations/services/prediction-service.interface';
12+
import { PredictionMode, PredictionResult } from '../../../../core/annotations/services/prediction-service.interface';
1713
import { getPredictionCache } from '../../../../core/annotations/services/utils';
1814
import { usePrevious } from '../../../../hooks/use-previous/use-previous.hook';
1915
import { useProject } from '../../../project-details/providers/project-provider/project-provider.component';
@@ -45,21 +41,17 @@ export const usePredictionsRoiQuery = ({
4541
const { inferenceService } = useApplicationServices();
4642

4743
const isValidRoi = roiId !== undefined && prevRoi !== roiId;
48-
4944
const isQueryEnabled = enabled && isValidRoi;
5045

5146
const predictionMode = isActiveLearningMode ? PredictionMode.AUTO : PredictionMode.ONLINE;
5247
const predictionCache = getPredictionCache(predictionMode);
53-
const isPredictionCacheNever = predictionCache === PredictionCache.NEVER;
5448

5549
const handleSuccessRef = useRef(onSuccess);
5650

5751
useLayoutEffect(() => {
5852
handleSuccessRef.current = onSuccess;
5953
}, [onSuccess]);
6054

61-
// TODO: extract explanation, combine with other prediction query
62-
6355
const query = useQuery<PredictionResult, AxiosError>({
6456
queryKey: QUERY_KEYS.SELECTED_MEDIA_ITEM.PREDICTIONS(
6557
selectedMediaItem?.identifier,
@@ -69,39 +61,25 @@ export const usePredictionsRoiQuery = ({
6961
roiId
7062
),
7163
queryFn: async ({ signal }) => {
72-
if (isQueryEnabled === false) {
73-
return { maps: [], annotations: [] };
64+
if (!isQueryEnabled || !selectedMediaItem) {
65+
return { annotations: [] };
7466
}
7567

76-
if (selectedMediaItem === undefined) {
77-
return { maps: [], annotations: [] };
78-
}
79-
80-
const explainPromise = isPredictionCacheNever
81-
? inferenceService.getExplanations(datasetIdentifier, selectedMediaItem, taskId, selectedInput, signal)
82-
: Promise.resolve([]);
83-
84-
return Promise.allSettled([
85-
inferenceService.getPredictions(
86-
datasetIdentifier,
87-
project.labels,
88-
selectedMediaItem,
89-
predictionCache,
90-
taskId,
91-
selectedInput,
92-
signal
93-
),
94-
explainPromise,
95-
]).then(([annotationsResponse, mapsResponse]) => {
96-
const maps = mapsResponse.status === 'fulfilled' ? mapsResponse.value : [];
97-
const annotations = annotationsResponse.status === 'fulfilled' ? annotationsResponse.value : [];
68+
const annotations = await inferenceService.getPredictions(
69+
datasetIdentifier,
70+
project.labels,
71+
selectedMediaItem,
72+
predictionCache,
73+
taskId,
74+
selectedInput,
75+
signal
76+
);
9877

99-
handleSuccessRef.current !== undefined && handleSuccessRef.current({ annotations, maps });
78+
handleSuccessRef.current?.({ annotations });
10079

101-
return { annotations, maps };
102-
});
80+
return { annotations };
10381
},
104-
initialData: { maps: [], annotations: [] },
82+
initialData: { annotations: [] },
10583
enabled: isQueryEnabled,
10684
});
10785

web_ui/src/pages/annotator/providers/selected-media-item-provider/default-selected-media-item-provider.component.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import QUERY_KEYS from '@geti/core/src/requests/query-keys';
77
import { useQuery, UseQueryResult } from '@tanstack/react-query';
88
import { noop } from 'lodash-es';
99

10+
import { ExplanationResult } from '../../../../core/annotations/services/inference-service.interface';
1011
import { PredictionResult } from '../../../../core/annotations/services/prediction-service.interface';
1112
import { SelectedMediaItemContext, SelectedMediaItemProps } from './selected-media-item-provider.component';
1213
import { SelectedMediaItem } from './selected-media-item.interface';
@@ -40,6 +41,7 @@ export const DefaultSelectedMediaItemProvider = ({
4041
selectedMediaItemQuery,
4142
setSelectedMediaItem: noop,
4243
predictionsQuery: { isLoading: false } as UseQueryResult<PredictionResult>,
44+
explanationsQuery: { isLoading: false } as UseQueryResult<ExplanationResult>,
4345
};
4446

4547
return <SelectedMediaItemContext.Provider value={value}>{children}</SelectedMediaItemContext.Provider>;

web_ui/src/pages/annotator/providers/selected-media-item-provider/selected-media-item-provider.component.tsx

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import { AxiosError } from 'axios';
1212
import { isEmpty, isEqual } from 'lodash-es';
1313

1414
import { Annotation } from '../../../../core/annotations/annotation.interface';
15+
import { ExplanationResult } from '../../../../core/annotations/services/inference-service.interface';
1516
import { PredictionMode, PredictionResult } from '../../../../core/annotations/services/prediction-service.interface';
1617
import { InferenceModel } from '../../../../core/annotations/services/visual-prompt-service';
1718
import { MediaItem } from '../../../../core/media/media.interface';
@@ -31,6 +32,7 @@ import { hasEmptyLabels } from '../../utils';
3132
import { useTask } from '../task-provider/task-provider.component';
3233
import { SelectedMediaItem } from './selected-media-item.interface';
3334
import { useAnnotationsQuery } from './use-annotations-query.hook';
35+
import { useExplanationsQuery } from './use-explanation-query.hook';
3436
import { useLoadImageQuery } from './use-load-image-query.hook';
3537
import { usePredictionsQuery } from './use-predictions-query.hook';
3638
import { useSelectedInferenceModel } from './use-selected-inference-model';
@@ -43,6 +45,7 @@ import {
4345

4446
export interface SelectedMediaItemProps {
4547
predictionsQuery: UseQueryResult<PredictionResult>;
48+
explanationsQuery: UseQueryResult<ExplanationResult>;
4649
selectedMediaItem: SelectedMediaItem | undefined;
4750
selectedMediaItemQuery: UseQueryResult<SelectedMediaItem>;
4851
setSelectedMediaItem: (media: MediaItem | undefined) => void;
@@ -87,7 +90,7 @@ const useIsSuggestPredictionEnabled = (projectIdentifier: ProjectIdentifier): bo
8790
};
8891

8992
/**
90-
* Load either online or auto predicitons based on the annotator mode
93+
* Load either online or auto predictions based on the annotator mode
9194
* When the user switches between both modes we don't want to refetch predictions,
9295
* so in both cases we will keep the queries mounted
9396
*/
@@ -169,6 +172,10 @@ export const SelectedMediaItemProvider = ({ children }: SelectedMediaItemProvide
169172
});
170173

171174
const predictionsQuery = usePredictionsQueryBasedOnAnnotatorMode(mediaItem);
175+
const explanationsQuery = useExplanationsQuery({
176+
datasetIdentifier,
177+
mediaItem,
178+
});
172179

173180
const selectedMediaItemQueryKey = [
174181
...QUERY_KEYS.SELECTED_MEDIA_ITEM.SELECTED(pendingMediaItem?.identifier, selectedTask?.id),
@@ -201,7 +208,7 @@ export const SelectedMediaItemProvider = ({ children }: SelectedMediaItemProvide
201208
}
202209
}
203210

204-
const predictionsData = predictionsQuery.data ?? { maps: [], annotations: [] };
211+
const predictionsData = predictionsQuery.data ?? { annotations: [] };
205212

206213
resolve([imageQuery.data, annotationsQuery.data, predictionsData]);
207214
}
@@ -266,6 +273,7 @@ export const SelectedMediaItemProvider = ({ children }: SelectedMediaItemProvide
266273

267274
const value: SelectedMediaItemProps = {
268275
predictionsQuery,
276+
explanationsQuery,
269277
selectedMediaItem,
270278
selectedMediaItemQuery,
271279
setSelectedMediaItem: setPendingMediaItem,

web_ui/src/pages/annotator/providers/selected-media-item-provider/selected-media-item.interface.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
// LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE
33

44
import { Annotation } from '../../../../core/annotations/annotation.interface';
5+
import { ExplanationResult } from '../../../../core/annotations/services/inference-service.interface';
56
import { PredictionResult } from '../../../../core/annotations/services/prediction-service.interface';
67
import { MediaItem } from '../../../../core/media/media.interface';
78

89
export type SelectedMediaItem = MediaItem & {
910
image: ImageData;
1011
annotations: Annotation[];
1112
predictions?: PredictionResult;
13+
explanations?: ExplanationResult;
1214
};

0 commit comments

Comments
 (0)