diff --git a/web_ui/packages/core/src/requests/query-keys.ts b/web_ui/packages/core/src/requests/query-keys.ts index 1f6a56c44b..f7bf249e52 100644 --- a/web_ui/packages/core/src/requests/query-keys.ts +++ b/web_ui/packages/core/src/requests/query-keys.ts @@ -118,6 +118,13 @@ const getSelectedMediaItemQueryKeys = () => { taskId?: string, roiId?: string ) => [...commonKey(mediaIdentifier), taskId, roiId, `${prefix}-predictions`, predictionCache], + + EXPLANATIONS: ( + datasetIdentifier: DatasetIdentifier, + mediaIdentifier: MediaIdentifier | undefined, + taskId?: string + ) => [...commonKey(mediaIdentifier), 'explanations', datasetIdentifier, mediaIdentifier, taskId], + SELECTED: (mediaIdentifier: MediaIdentifier | undefined, taskId?: string) => [ ...commonKey(mediaIdentifier), taskId, diff --git a/web_ui/src/core/annotations/services/inference-service.interface.ts b/web_ui/src/core/annotations/services/inference-service.interface.ts index aa7edf671a..4e00067362 100644 --- a/web_ui/src/core/annotations/services/inference-service.interface.ts +++ b/web_ui/src/core/annotations/services/inference-service.interface.ts @@ -7,12 +7,16 @@ import { ProjectIdentifier } from '../../projects/core.interface'; import { DatasetIdentifier } from '../../projects/dataset.interface'; import { Annotation, TaskChainInput } from '../annotation.interface'; import { Explanation } from '../prediction.interface'; -import { InferenceServerStatusResult, PredictionCache, PredictionMode } from './prediction-service.interface'; +import { PredictionCache, PredictionMode } from './prediction-service.interface'; import { VideoPaginationOptions } from './video-pagination-options.interface'; export type InferenceResult = ReadonlyArray; export type ExplanationResult = Explanation[]; +export interface InferenceServerStatusResult { + isInferenceServerReady: boolean; +} + export interface InferenceService { getTestPredictions: ( projectIdentifier: ProjectIdentifier, diff --git a/web_ui/src/core/annotations/services/inference-service/api-inference-service.ts b/web_ui/src/core/annotations/services/inference-service/api-inference-service.ts index 24298b3756..bec94833ba 100644 --- a/web_ui/src/core/annotations/services/inference-service/api-inference-service.ts +++ b/web_ui/src/core/annotations/services/inference-service/api-inference-service.ts @@ -30,8 +30,13 @@ import { PipelineServerStatusDTO, } from '../../dtos/prediction.interface'; import { Rect } from '../../shapes.interface'; -import { ExplanationResult, InferenceResult, InferenceService } from '../inference-service.interface'; -import { InferenceServerStatusResult, PredictionCache, PredictionMode } from '../prediction-service.interface'; +import { + ExplanationResult, + InferenceResult, + InferenceServerStatusResult, + InferenceService, +} from '../inference-service.interface'; +import { PredictionCache, PredictionMode } from '../prediction-service.interface'; import { buildPredictionParams, getExplanations as convertExplanations, diff --git a/web_ui/src/core/annotations/services/prediction-service.interface.ts b/web_ui/src/core/annotations/services/prediction-service.interface.ts index bfc9817e62..f4c66eec1e 100644 --- a/web_ui/src/core/annotations/services/prediction-service.interface.ts +++ b/web_ui/src/core/annotations/services/prediction-service.interface.ts @@ -2,15 +2,9 @@ // LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE import { Annotation } from '../annotation.interface'; -import { Explanation } from '../prediction.interface'; export interface PredictionResult { annotations: ReadonlyArray; - maps: Explanation[]; -} - -export interface InferenceServerStatusResult { - isInferenceServerReady: boolean; } export enum PredictionMode { diff --git a/web_ui/src/pages/annotator/annotation/layers/layers-factory.test.tsx b/web_ui/src/pages/annotator/annotation/layers/layers-factory.test.tsx index af0e39a2ee..b2e92b8f4b 100644 --- a/web_ui/src/pages/annotator/annotation/layers/layers-factory.test.tsx +++ b/web_ui/src/pages/annotator/annotation/layers/layers-factory.test.tsx @@ -59,6 +59,7 @@ describe('LayersFactory', () => { selectedMediaItem, selectedMediaItemQuery: { isLoading: false }, predictionsQuery: { data: undefined }, + explanationsQuery: { data: undefined }, } as SelectedMediaItemProps); const response = await annotatorRender( diff --git a/web_ui/src/pages/annotator/providers/annotator-provider/annotator-provider.component.tsx b/web_ui/src/pages/annotator/providers/annotator-provider/annotator-provider.component.tsx index 12cbdcb15e..450426b6aa 100644 --- a/web_ui/src/pages/annotator/providers/annotator-provider/annotator-provider.component.tsx +++ b/web_ui/src/pages/annotator/providers/annotator-provider/annotator-provider.component.tsx @@ -109,7 +109,7 @@ export const AnnotatorProvider = ({ children }: AnnotatorProviderProps): JSX.Ele settings={userProjectSettings} userAnnotationScene={userAnnotationScene} initPredictions={initialPredictionAnnotations} - explanations={selectedMediaItem?.predictions?.maps || EMPTY_EXPLANATION} + explanations={selectedMediaItem?.explanations || EMPTY_EXPLANATION} > ({ [FEATURES_KEYS.INITIAL_PREDICTION]: { diff --git a/web_ui/src/pages/annotator/providers/prediction-provider/prediction-provider.component.tsx b/web_ui/src/pages/annotator/providers/prediction-provider/prediction-provider.component.tsx index 1e4df08106..d41a09c410 100644 --- a/web_ui/src/pages/annotator/providers/prediction-provider/prediction-provider.component.tsx +++ b/web_ui/src/pages/annotator/providers/prediction-provider/prediction-provider.component.tsx @@ -326,7 +326,7 @@ export const PredictionProvider = ({ selectedInput, taskId: String(selectedTask?.id), enabled: isPredictionsQueryEnabled, - onSuccess: runWhenNotDrawing(({ annotations: newRawPredictions, maps: newMaps }: PredictionResult) => { + onSuccess: runWhenNotDrawing(({ annotations: newRawPredictions }: PredictionResult) => { const selectedPredictions = selectAnnotations(newRawPredictions, userSceneSelectedInputs); if (isEmpty(userAnnotations)) { @@ -348,7 +348,6 @@ export const PredictionProvider = ({ canUpdatePrediction && userAnnotationScene.updateAnnotation(newPrediction); } - setExplanations(newMaps); setRawPredictions(selectedPredictions); // Optionally update video timeline predictions diff --git a/web_ui/src/pages/annotator/providers/prediction-provider/use-inference-server-status.ts b/web_ui/src/pages/annotator/providers/prediction-provider/use-inference-server-status.ts index 06def5b3bb..b27638d4aa 100644 --- a/web_ui/src/pages/annotator/providers/prediction-provider/use-inference-server-status.ts +++ b/web_ui/src/pages/annotator/providers/prediction-provider/use-inference-server-status.ts @@ -6,7 +6,7 @@ import { useApplicationServices } from '@geti/core/src/services/application-serv import { useQuery, UseQueryResult } from '@tanstack/react-query'; import { AxiosError } from 'axios'; -import { InferenceServerStatusResult } from '../../../../core/annotations/services/prediction-service.interface'; +import { InferenceServerStatusResult } from '../../../../core/annotations/services/inference-service.interface'; import { ProjectIdentifier } from '../../../../core/projects/core.interface'; import { useTask } from '../task-provider/task-provider.component'; diff --git a/web_ui/src/pages/annotator/providers/prediction-provider/use-prediction-roi-query.hook.test.tsx b/web_ui/src/pages/annotator/providers/prediction-provider/use-prediction-roi-query.hook.test.tsx index 895813f382..5b59914f27 100644 --- a/web_ui/src/pages/annotator/providers/prediction-provider/use-prediction-roi-query.hook.test.tsx +++ b/web_ui/src/pages/annotator/providers/prediction-provider/use-prediction-roi-query.hook.test.tsx @@ -135,7 +135,7 @@ describe('usePredictionsRoiQuery', () => { }); await waitFor(() => { - expect(result.current.data).toEqual({ annotations: [], maps: [] }); + expect(result.current.data).toEqual({ annotations: [] }); }); }); @@ -149,7 +149,6 @@ describe('usePredictionsRoiQuery', () => { await waitFor(() => { expect(result.current.data).toEqual({ annotations: [], - maps: [], }); }); @@ -181,11 +180,9 @@ describe('usePredictionsRoiQuery', () => { expect.anything() ); }); - - expect(mockedInferenceService.getExplanations).not.toHaveBeenCalled(); }); - describe('PredictionMode.ONLINE is sent as PredictionCache.NEVER, calls "getPredictions" and "getExplanations"', () => { + describe('PredictionMode.ONLINE is sent as PredictionCache.NEVER, calls "getPredictions"', () => { it('successful responses', async () => { jest.mocked(useAnnotatorMode).mockImplementation(() => ({ isActiveLearningMode: false, @@ -210,15 +207,6 @@ describe('usePredictionsRoiQuery', () => { expect.anything() ); }); - - expect(mockedInferenceService.getExplanations).toHaveBeenLastCalledWith( - datasetIdentifier, - selectedMediaItem, - taskId, - selectedInput, - // AbortController - expect.anything() - ); }); it('rejected requests are handle as empty', async () => { @@ -227,7 +215,6 @@ describe('usePredictionsRoiQuery', () => { currentMode: ANNOTATOR_MODE.PREDICTION, })); jest.mocked(mockedInferenceService.getPredictions).mockRejectedValue('test error'); - jest.mocked(mockedInferenceService.getExplanations).mockRejectedValue('test error'); const { result } = renderPredictionsRoiQuery({ selectedInput, @@ -236,7 +223,7 @@ describe('usePredictionsRoiQuery', () => { }); await waitFor(() => { - expect(result.current.data).toEqual({ annotations: [], maps: [] }); + expect(result.current.data).toEqual({ annotations: [] }); }); }); }); diff --git a/web_ui/src/pages/annotator/providers/prediction-provider/use-prediction-roi-query.hook.ts b/web_ui/src/pages/annotator/providers/prediction-provider/use-prediction-roi-query.hook.ts index c48c4f17ef..cb66b4d9b9 100644 --- a/web_ui/src/pages/annotator/providers/prediction-provider/use-prediction-roi-query.hook.ts +++ b/web_ui/src/pages/annotator/providers/prediction-provider/use-prediction-roi-query.hook.ts @@ -9,11 +9,7 @@ import { useQuery } from '@tanstack/react-query'; import { AxiosError } from 'axios'; import { TaskChainInput } from '../../../../core/annotations/annotation.interface'; -import { - PredictionCache, - PredictionMode, - PredictionResult, -} from '../../../../core/annotations/services/prediction-service.interface'; +import { PredictionMode, PredictionResult } from '../../../../core/annotations/services/prediction-service.interface'; import { getPredictionCache } from '../../../../core/annotations/services/utils'; import { usePrevious } from '../../../../hooks/use-previous/use-previous.hook'; import { useProject } from '../../../project-details/providers/project-provider/project-provider.component'; @@ -45,12 +41,10 @@ export const usePredictionsRoiQuery = ({ const { inferenceService } = useApplicationServices(); const isValidRoi = roiId !== undefined && prevRoi !== roiId; - const isQueryEnabled = enabled && isValidRoi; const predictionMode = isActiveLearningMode ? PredictionMode.AUTO : PredictionMode.ONLINE; const predictionCache = getPredictionCache(predictionMode); - const isPredictionCacheNever = predictionCache === PredictionCache.NEVER; const handleSuccessRef = useRef(onSuccess); @@ -58,8 +52,6 @@ export const usePredictionsRoiQuery = ({ handleSuccessRef.current = onSuccess; }, [onSuccess]); - // TODO: extract explanation, combine with other prediction query - const query = useQuery({ queryKey: QUERY_KEYS.SELECTED_MEDIA_ITEM.PREDICTIONS( selectedMediaItem?.identifier, @@ -69,39 +61,25 @@ export const usePredictionsRoiQuery = ({ roiId ), queryFn: async ({ signal }) => { - if (isQueryEnabled === false) { - return { maps: [], annotations: [] }; + if (!isQueryEnabled || !selectedMediaItem) { + return { annotations: [] }; } - if (selectedMediaItem === undefined) { - return { maps: [], annotations: [] }; - } - - const explainPromise = isPredictionCacheNever - ? inferenceService.getExplanations(datasetIdentifier, selectedMediaItem, taskId, selectedInput, signal) - : Promise.resolve([]); - - return Promise.allSettled([ - inferenceService.getPredictions( - datasetIdentifier, - project.labels, - selectedMediaItem, - predictionCache, - taskId, - selectedInput, - signal - ), - explainPromise, - ]).then(([annotationsResponse, mapsResponse]) => { - const maps = mapsResponse.status === 'fulfilled' ? mapsResponse.value : []; - const annotations = annotationsResponse.status === 'fulfilled' ? annotationsResponse.value : []; + const annotations = await inferenceService.getPredictions( + datasetIdentifier, + project.labels, + selectedMediaItem, + predictionCache, + taskId, + selectedInput, + signal + ); - handleSuccessRef.current !== undefined && handleSuccessRef.current({ annotations, maps }); + handleSuccessRef.current?.({ annotations }); - return { annotations, maps }; - }); + return { annotations }; }, - initialData: { maps: [], annotations: [] }, + initialData: { annotations: [] }, enabled: isQueryEnabled, }); diff --git a/web_ui/src/pages/annotator/providers/selected-media-item-provider/default-selected-media-item-provider.component.tsx b/web_ui/src/pages/annotator/providers/selected-media-item-provider/default-selected-media-item-provider.component.tsx index a465dd846c..cbc9abe37a 100644 --- a/web_ui/src/pages/annotator/providers/selected-media-item-provider/default-selected-media-item-provider.component.tsx +++ b/web_ui/src/pages/annotator/providers/selected-media-item-provider/default-selected-media-item-provider.component.tsx @@ -7,6 +7,7 @@ import QUERY_KEYS from '@geti/core/src/requests/query-keys'; import { useQuery, UseQueryResult } from '@tanstack/react-query'; import { noop } from 'lodash-es'; +import { ExplanationResult } from '../../../../core/annotations/services/inference-service.interface'; import { PredictionResult } from '../../../../core/annotations/services/prediction-service.interface'; import { SelectedMediaItemContext, SelectedMediaItemProps } from './selected-media-item-provider.component'; import { SelectedMediaItem } from './selected-media-item.interface'; @@ -40,6 +41,7 @@ export const DefaultSelectedMediaItemProvider = ({ selectedMediaItemQuery, setSelectedMediaItem: noop, predictionsQuery: { isLoading: false } as UseQueryResult, + explanationsQuery: { isLoading: false } as UseQueryResult, }; return {children}; diff --git a/web_ui/src/pages/annotator/providers/selected-media-item-provider/selected-media-item-provider.component.tsx b/web_ui/src/pages/annotator/providers/selected-media-item-provider/selected-media-item-provider.component.tsx index 7632d912cb..0d3b0c9075 100644 --- a/web_ui/src/pages/annotator/providers/selected-media-item-provider/selected-media-item-provider.component.tsx +++ b/web_ui/src/pages/annotator/providers/selected-media-item-provider/selected-media-item-provider.component.tsx @@ -12,6 +12,7 @@ import { AxiosError } from 'axios'; import { isEmpty, isEqual } from 'lodash-es'; import { Annotation } from '../../../../core/annotations/annotation.interface'; +import { ExplanationResult } from '../../../../core/annotations/services/inference-service.interface'; import { PredictionMode, PredictionResult } from '../../../../core/annotations/services/prediction-service.interface'; import { InferenceModel } from '../../../../core/annotations/services/visual-prompt-service'; import { MediaItem } from '../../../../core/media/media.interface'; @@ -31,6 +32,7 @@ import { hasEmptyLabels } from '../../utils'; import { useTask } from '../task-provider/task-provider.component'; import { SelectedMediaItem } from './selected-media-item.interface'; import { useAnnotationsQuery } from './use-annotations-query.hook'; +import { useExplanationsQuery } from './use-explanation-query.hook'; import { useLoadImageQuery } from './use-load-image-query.hook'; import { usePredictionsQuery } from './use-predictions-query.hook'; import { useSelectedInferenceModel } from './use-selected-inference-model'; @@ -43,6 +45,7 @@ import { export interface SelectedMediaItemProps { predictionsQuery: UseQueryResult; + explanationsQuery: UseQueryResult; selectedMediaItem: SelectedMediaItem | undefined; selectedMediaItemQuery: UseQueryResult; setSelectedMediaItem: (media: MediaItem | undefined) => void; @@ -87,7 +90,7 @@ const useIsSuggestPredictionEnabled = (projectIdentifier: ProjectIdentifier): bo }; /** - * Load either online or auto predicitons based on the annotator mode + * Load either online or auto predictions based on the annotator mode * When the user switches between both modes we don't want to refetch predictions, * so in both cases we will keep the queries mounted */ @@ -169,10 +172,20 @@ export const SelectedMediaItemProvider = ({ children }: SelectedMediaItemProvide }); const predictionsQuery = usePredictionsQueryBasedOnAnnotatorMode(mediaItem); + const explanationsQuery = useExplanationsQuery({ + datasetIdentifier, + mediaItem, + taskId: selectedTask?.id, + }); const selectedMediaItemQueryKey = [ ...QUERY_KEYS.SELECTED_MEDIA_ITEM.SELECTED(pendingMediaItem?.identifier, selectedTask?.id), - [imageQuery.fetchStatus, annotationsQuery.fetchStatus, predictionsQuery.fetchStatus], + [ + imageQuery.fetchStatus, + annotationsQuery.fetchStatus, + predictionsQuery.fetchStatus, + explanationsQuery.fetchStatus, + ], ]; const isSelectedMediaItemQueryEnabled = mediaItem !== undefined; @@ -184,31 +197,32 @@ export const SelectedMediaItemProvider = ({ children }: SelectedMediaItemProvide throw new Error("Can't fetch undefined media item"); } - const [image, annotations, predictions] = await new Promise<[ImageData, Annotation[], PredictionResult]>( - (resolve, reject) => { - if (imageQuery.isError) { - reject({ - message: 'Failed loading media item. Please try refreshing or selecting a different item.', - }); - } + const [image, annotations, predictions, explanations] = await new Promise< + [ImageData, Annotation[], PredictionResult, ExplanationResult] + >((resolve, reject) => { + if (imageQuery.isError) { + reject({ + message: 'Failed loading media item. Please try refreshing or selecting a different item.', + }); + } - if (imageQuery.data && annotationsQuery.data) { - if (!predictionsQuery.data && predictionsQuery.isFetching) { - // If we do not yet have predictions and the user has not yet made any annotations - // for the selected task, then we will wait for predictions - if (isNotAnnotatedForTask(annotationsQuery.data, selectedTask)) { - return; - } + if (imageQuery.data && annotationsQuery.data) { + if (!predictionsQuery.data && predictionsQuery.isFetching) { + // If we do not yet have predictions and the user has not yet made any annotations + // for the selected task, then we will wait for predictions + if (isNotAnnotatedForTask(annotationsQuery.data, selectedTask)) { + return; } + } - const predictionsData = predictionsQuery.data ?? { maps: [], annotations: [] }; + const predictionsData = predictionsQuery.data ?? { annotations: [] }; + const explanationsData = explanationsQuery.data ?? []; - resolve([imageQuery.data, annotationsQuery.data, predictionsData]); - } + resolve([imageQuery.data, annotationsQuery.data, predictionsData, explanationsData]); } - ); + }); - const newlySelectedMediaItem = { ...mediaItem, image, annotations, predictions }; + const newlySelectedMediaItem = { ...mediaItem, image, annotations, predictions, explanations }; if (isSingleDomainProject(isClassificationDomain)) { return { @@ -266,6 +280,7 @@ export const SelectedMediaItemProvider = ({ children }: SelectedMediaItemProvide const value: SelectedMediaItemProps = { predictionsQuery, + explanationsQuery, selectedMediaItem, selectedMediaItemQuery, setSelectedMediaItem: setPendingMediaItem, diff --git a/web_ui/src/pages/annotator/providers/selected-media-item-provider/selected-media-item.interface.ts b/web_ui/src/pages/annotator/providers/selected-media-item-provider/selected-media-item.interface.ts index e13d701e12..50e7f602a9 100644 --- a/web_ui/src/pages/annotator/providers/selected-media-item-provider/selected-media-item.interface.ts +++ b/web_ui/src/pages/annotator/providers/selected-media-item-provider/selected-media-item.interface.ts @@ -2,6 +2,7 @@ // LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE import { Annotation } from '../../../../core/annotations/annotation.interface'; +import { ExplanationResult } from '../../../../core/annotations/services/inference-service.interface'; import { PredictionResult } from '../../../../core/annotations/services/prediction-service.interface'; import { MediaItem } from '../../../../core/media/media.interface'; @@ -9,4 +10,5 @@ export type SelectedMediaItem = MediaItem & { image: ImageData; annotations: Annotation[]; predictions?: PredictionResult; + explanations?: ExplanationResult; }; diff --git a/web_ui/src/pages/annotator/providers/selected-media-item-provider/use-explanation-query.hook.ts b/web_ui/src/pages/annotator/providers/selected-media-item-provider/use-explanation-query.hook.ts new file mode 100644 index 0000000000..554331b695 --- /dev/null +++ b/web_ui/src/pages/annotator/providers/selected-media-item-provider/use-explanation-query.hook.ts @@ -0,0 +1,57 @@ +// Copyright (C) 2022-2025 Intel Corporation +// LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE + +import QUERY_KEYS from '@geti/core/src/requests/query-keys'; +import { useApplicationServices } from '@geti/core/src/services/application-services-provider.component'; +import { QueryKey, useQuery, UseQueryResult } from '@tanstack/react-query'; +import { AxiosError } from 'axios'; + +import { ExplanationResult } from '../../../../core/annotations/services/inference-service.interface'; +import { MediaItem } from '../../../../core/media/media.interface'; +import { DatasetIdentifier } from '../../../../core/projects/dataset.interface'; + +interface UseGetExplanations { + datasetIdentifier: DatasetIdentifier; + mediaItem: MediaItem | undefined; + enabled?: boolean; + taskId?: string; +} + +export const useExplanationsQuery = ({ + datasetIdentifier, + mediaItem, + enabled = true, + taskId, +}: UseGetExplanations): UseQueryResult => { + const { inferenceService } = useApplicationServices(); + + const queryKey: QueryKey = QUERY_KEYS.SELECTED_MEDIA_ITEM.EXPLANATIONS( + datasetIdentifier, + mediaItem?.identifier, + taskId + ); + + return useQuery({ + queryKey, + queryFn: async ({ signal }) => { + if (!mediaItem) throw new Error("Can't fetch undefined media item"); + + try { + const explanations = await inferenceService.getExplanations( + datasetIdentifier, + mediaItem, + taskId, + undefined, + signal + ); + + return explanations ?? []; + } catch (_error) { + return []; + } + }, + enabled: enabled && !!mediaItem, + staleTime: 5 * 60_000, + gcTime: 0, + }); +}; diff --git a/web_ui/src/pages/annotator/providers/selected-media-item-provider/use-predictions-query.hook.test.tsx b/web_ui/src/pages/annotator/providers/selected-media-item-provider/use-predictions-query.hook.test.tsx index 5519a40aec..34b69f2659 100644 --- a/web_ui/src/pages/annotator/providers/selected-media-item-provider/use-predictions-query.hook.test.tsx +++ b/web_ui/src/pages/annotator/providers/selected-media-item-provider/use-predictions-query.hook.test.tsx @@ -49,7 +49,6 @@ describe('usePredictionsQuery', (): void => { beforeEach(() => { mockedInferenceService.getPredictions = jest.fn(); - mockedInferenceService.getExplanations = jest.fn(); }); it('task-chain projects call the service with taskId ', async (): Promise => { @@ -69,15 +68,6 @@ describe('usePredictionsQuery', (): void => { expect(result.current).toBeDefined(); }); - expect(mockedInferenceService.getExplanations).toHaveBeenCalledWith( - datasetIdentifier, - mediaItem, - taskId, - undefined, - - // AbortController - expect.anything() - ); expect(mockedInferenceService.getPredictions).toHaveBeenCalledWith( datasetIdentifier, coreLabels, @@ -97,7 +87,6 @@ describe('usePredictionsQuery', (): void => { }); await waitFor(() => { - expect(mockedInferenceService.getExplanations).not.toHaveBeenCalled(); expect(mockedInferenceService.getPredictions).toHaveBeenCalledWith( datasetIdentifier, coreLabels, @@ -111,7 +100,7 @@ describe('usePredictionsQuery', (): void => { }); }); - it('PredictionMode LATEST is handle as PredictionCache.ALWAYS', async (): Promise => { + it('PredictionMode LATEST is handled as PredictionCache.ALWAYS', async (): Promise => { const onSuccess = jest.fn(); renderHookWithProviders( @@ -126,7 +115,6 @@ describe('usePredictionsQuery', (): void => { expect(onSuccess).toHaveBeenCalled(); }); - expect(mockedInferenceService.getExplanations).not.toHaveBeenCalled(); expect(mockedInferenceService.getPredictions).toHaveBeenCalledWith( datasetIdentifier, coreLabels, @@ -139,9 +127,11 @@ describe('usePredictionsQuery', (): void => { ); }); - it('PredictionMode.ONLINE is sent as PredictionCache.NEVER, getExplanations is called', async (): Promise => { + it('PredictionMode.ONLINE is handled as PredictionCache.NEVER', async (): Promise => { + const onSuccess = jest.fn(); + renderHookWithProviders( - () => usePredictionsQuery({ ...predictionArguments, predictionId: PredictionMode.ONLINE }), + () => usePredictionsQuery({ ...predictionArguments, onSuccess, predictionId: PredictionMode.ONLINE }), { wrapper, providerProps: { ...initialProps }, @@ -149,14 +139,7 @@ describe('usePredictionsQuery', (): void => { ); await waitFor(() => { - expect(mockedInferenceService.getExplanations).toHaveBeenCalledWith( - datasetIdentifier, - mediaItem, - undefined, - undefined, - // AbortController - expect.anything() - ); + expect(onSuccess).toHaveBeenCalled(); }); expect(mockedInferenceService.getPredictions).toHaveBeenCalledWith( @@ -170,25 +153,4 @@ describe('usePredictionsQuery', (): void => { expect.anything() ); }); - - it('does not call "getExplanations" for keypoint detection projects', async (): Promise => { - const mockedProjectService = createInMemoryProjectService(); - - mockedProjectService.getProject = async () => - getMockedProject({ tasks: [getMockedTask({ domain: DOMAIN.KEYPOINT_DETECTION })] }); - - renderHookWithProviders( - () => usePredictionsQuery({ ...predictionArguments, predictionId: PredictionMode.ONLINE }), - { - wrapper, - providerProps: { ...initialProps, projectService: mockedProjectService }, - } - ); - - await waitFor(() => { - expect(mockedInferenceService.getPredictions).toHaveBeenCalled(); - }); - - expect(mockedInferenceService.getExplanations).not.toHaveBeenCalled(); - }); }); diff --git a/web_ui/src/pages/annotator/providers/selected-media-item-provider/use-predictions-query.hook.ts b/web_ui/src/pages/annotator/providers/selected-media-item-provider/use-predictions-query.hook.ts index 203db00e86..1b7053b51b 100644 --- a/web_ui/src/pages/annotator/providers/selected-media-item-provider/use-predictions-query.hook.ts +++ b/web_ui/src/pages/annotator/providers/selected-media-item-provider/use-predictions-query.hook.ts @@ -21,7 +21,6 @@ import { MEDIA_TYPE } from '../../../../core/media/base-media.interface'; import { MediaItem } from '../../../../core/media/media.interface'; import { isVideoFrame } from '../../../../core/media/video.interface'; import { DatasetIdentifier } from '../../../../core/projects/dataset.interface'; -import { isKeypointDetection } from '../../../../core/projects/domains'; import { useTask } from '../task-provider/task-provider.component'; import { updateVideoTimelineQuery } from './utils'; @@ -55,12 +54,8 @@ export const usePredictionsQuery = ({ const isTaskChainProject = tasks.length > 1; const predictionCache = getPredictionCache(predictionId); const isPredictionCacheAuto = predictionCache === PredictionCache.AUTO; - const isPredictionCacheNever = predictionCache === PredictionCache.NEVER; const isQueryEnabled = enabled && mediaItem !== undefined; - const isKeypointDetectionTask = selectedTask !== null && isKeypointDetection(selectedTask.domain); - const shouldFetchExplanations = isPredictionCacheNever && !isKeypointDetectionTask; - const queryKey: QueryKey = [ ...QUERY_KEYS.SELECTED_MEDIA_ITEM.PREDICTIONS(mediaItem?.identifier, 'initial', predictionCache, taskId), predictionId, @@ -69,7 +64,6 @@ export const usePredictionsQuery = ({ const handleSuccessRef = useRef(onSuccess); const handleErrorRef = useRef(onError); - // TODO: extract explanations const predictionsQuery = useQuery({ queryKey, queryFn: async ({ signal }) => { @@ -85,24 +79,12 @@ export const usePredictionsQuery = ({ taskId ?? tasks[0].id, signal ); - return { - maps: [], annotations, }; } - const explainPromise = shouldFetchExplanations - ? inferenceService.getExplanations( - datasetIdentifier, - mediaItem, - isTaskChainProject ? taskId : undefined, - undefined, - signal - ) - : Promise.resolve([]); - - const predictionsPromise = inferenceService.getPredictions( + const annotations = await inferenceService.getPredictions( datasetIdentifier, coreLabels, mediaItem, @@ -112,14 +94,8 @@ export const usePredictionsQuery = ({ signal ); - const [annotationsResponse, explanationResponse] = await Promise.allSettled([ - predictionsPromise, - explainPromise, - ]); - return { - maps: explanationResponse.status === 'fulfilled' ? explanationResponse.value : [], - annotations: annotationsResponse.status === 'fulfilled' ? annotationsResponse.value : [], + annotations, }; }, enabled: isQueryEnabled, diff --git a/web_ui/src/pages/project-details/components/project-model/training-dataset/training-dataset-details-preview/training-dataset-details-preview.component.tsx b/web_ui/src/pages/project-details/components/project-model/training-dataset/training-dataset-details-preview/training-dataset-details-preview.component.tsx index b665551e2f..41213fce71 100644 --- a/web_ui/src/pages/project-details/components/project-model/training-dataset/training-dataset-details-preview/training-dataset-details-preview.component.tsx +++ b/web_ui/src/pages/project-details/components/project-model/training-dataset/training-dataset-details-preview/training-dataset-details-preview.component.tsx @@ -32,6 +32,7 @@ import { DatasetList } from '../../../../../annotator/components/sidebar/dataset import { ANNOTATOR_MODE } from '../../../../../annotator/core/annotation-tool-context.interface'; import { AnnotationToolProvider } from '../../../../../annotator/providers/annotation-tool-provider/annotation-tool-provider.component'; import { useAnnotationsQuery } from '../../../../../annotator/providers/selected-media-item-provider/use-annotations-query.hook'; +import { useExplanationsQuery } from '../../../../../annotator/providers/selected-media-item-provider/use-explanation-query.hook'; import { useLoadImageQuery } from '../../../../../annotator/providers/selected-media-item-provider/use-load-image-query.hook'; import { usePredictionsQuery } from '../../../../../annotator/providers/selected-media-item-provider/use-predictions-query.hook'; import { useProject } from '../../../../providers/project-provider/project-provider.component'; @@ -98,9 +99,14 @@ export const TrainingDatasetDetailsPreview = ({ enabled: isPredictionsEnabled, predictionId: isVisualPrompt ? PredictionMode.VISUAL_PROMPT : undefined, }); + const { data: explanations } = useExplanationsQuery({ + mediaItem: selectedPreviewItem, + datasetIdentifier, + taskId, + enabled: isPredictionsEnabled, + }); const predictions = useVisibleAnnotations(predictionsQuery.data?.annotations ?? []); - const explanations = predictionsQuery.data?.maps; const annotations = useVisibleAnnotations(annotationsQuery.data ?? []); diff --git a/web_ui/src/pages/project-details/components/project-test/test-details-preview/test-details-preview.component.tsx b/web_ui/src/pages/project-details/components/project-test/test-details-preview/test-details-preview.component.tsx index e1b3a21850..01c8300579 100644 --- a/web_ui/src/pages/project-details/components/project-test/test-details-preview/test-details-preview.component.tsx +++ b/web_ui/src/pages/project-details/components/project-test/test-details-preview/test-details-preview.component.tsx @@ -199,7 +199,7 @@ export const TestDetailsPreview = ({ const datasetIdentifier = useTestDatasetIdentifier(test); - const { imageQuery, annotationsQuery, predictionsQuery, testResult } = useTestResultsQuery( + const { imageQuery, annotationsQuery, predictionsQuery, explanationsQuery, testResult } = useTestResultsQuery( datasetIdentifier, selectedMediaItem, testMediaItem, @@ -210,7 +210,7 @@ export const TestDetailsPreview = ({ const annotations = useVisibleAnnotations(annotationsQuery.data ?? []); const predictions = useVisibleAnnotations(predictionsQuery.data?.annotations ?? []); - const explanations = sortExplanationsByName(predictionsQuery.data?.maps); + const explanations = sortExplanationsByName(explanationsQuery.data ?? []); useEffect(() => { setSelectedMediaItem(testMediaItem.media); diff --git a/web_ui/src/pages/project-details/components/project-test/test-details-preview/use-test-results-query.hook.test.tsx b/web_ui/src/pages/project-details/components/project-test/test-details-preview/use-test-results-query.hook.test.tsx index 266cafeaea..e7700cf44f 100644 --- a/web_ui/src/pages/project-details/components/project-test/test-details-preview/use-test-results-query.hook.test.tsx +++ b/web_ui/src/pages/project-details/components/project-test/test-details-preview/use-test-results-query.hook.test.tsx @@ -80,7 +80,8 @@ describe('useTestResultsQuery', () => { await waitFor(() => { expect(mockedInferenceService.getExplanations).toHaveBeenCalled(); expect(mockedInferenceService.getTestPredictions).toHaveBeenCalled(); - expect(result.current.predictionsQuery.data).toEqual({ annotations: [mockAnnotation], maps: [] }); + expect(result.current.predictionsQuery.data).toEqual({ annotations: [mockAnnotation] }); + expect(result.current.explanationsQuery.data).toEqual([]); }); }); @@ -103,7 +104,8 @@ describe('useTestResultsQuery', () => { await waitFor(() => { expect(mockedInferenceService.getExplanations).toHaveBeenCalled(); expect(mockedInferenceService.getTestPredictions).toHaveBeenCalled(); - expect(result.current.predictionsQuery.data).toEqual({ annotations: [], maps: [mockedExplanation] }); + expect(result.current.predictionsQuery.data).toEqual({ annotations: [] }); + expect(result.current.explanationsQuery.data).toEqual([mockedExplanation]); }); }); }); diff --git a/web_ui/src/pages/project-details/components/project-test/test-details-preview/use-test-results-query.hook.ts b/web_ui/src/pages/project-details/components/project-test/test-details-preview/use-test-results-query.hook.ts index 96c40c7858..ddf314357a 100644 --- a/web_ui/src/pages/project-details/components/project-test/test-details-preview/use-test-results-query.hook.ts +++ b/web_ui/src/pages/project-details/components/project-test/test-details-preview/use-test-results-query.hook.ts @@ -15,6 +15,7 @@ import { TestMediaItem } from '../../../../../core/tests/test-media.interface'; import { useProjectIdentifier } from '../../../../../hooks/use-project-identifier/use-project-identifier'; import { isNonEmptyString } from '../../../../../shared/utils'; import { useAnnotationsQuery } from '../../../../annotator/providers/selected-media-item-provider/use-annotations-query.hook'; +import { useExplanationsQuery } from '../../../../annotator/providers/selected-media-item-provider/use-explanation-query.hook'; import { useLoadImageQuery } from '../../../../annotator/providers/selected-media-item-provider/use-load-image-query.hook'; import { useProject } from '../../../providers/project-provider/project-provider.component'; @@ -71,27 +72,33 @@ export const useTestResultsQuery = ( const predictionsQuery = useQuery({ queryKey: QUERY_KEYS.TEST_PREDICTIONS(projectIdentifier, testId, String(testResult?.predictionId)), - queryFn: () => - Promise.allSettled([ - inferenceService.getTestPredictions( + queryFn: async () => { + try { + const predictions = await inferenceService.getTestPredictions( datasetIdentifier, labels, testId, String(testResult?.predictionId) - ), - inferenceService.getExplanations(datasetIdentifier, mediaItem), - ]).then(([predictions, explanations]) => { - // @ts-expect-error PromiseSettledResult type is not exported - // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Promise/allSettled - return { annotations: predictions.value ?? [], maps: explanations.value ?? [] }; - }), + ); + + return { annotations: predictions ?? [] }; + } catch (_error) { + return { annotations: [] }; + } + }, enabled: isNonEmptyString(testResult?.predictionId), }); + const explanationsQuery = useExplanationsQuery({ + datasetIdentifier, + mediaItem, + }); + return { imageQuery, annotationsQuery, predictionsQuery, + explanationsQuery, testResult, }; };