web_ui/packages/core/src/requests/query-keys.ts (7 additions & 0 deletions)
@@ -118,6 +118,13 @@ const getSelectedMediaItemQueryKeys = () => {
taskId?: string,
roiId?: string
) => [...commonKey(mediaIdentifier), taskId, roiId, `${prefix}-predictions`, predictionCache],

+ EXPLANATIONS: (
+     datasetIdentifier: DatasetIdentifier,
+     mediaIdentifier: MediaIdentifier | undefined,
+     taskId?: string
Contributor: Should this also depend on the ROI? IIRC, when we are in a task chain we request the explanation map based on a given ROI.

+ ) => [...commonKey(mediaIdentifier), 'explanations', datasetIdentifier, mediaIdentifier, taskId],

SELECTED: (mediaIdentifier: MediaIdentifier | undefined, taskId?: string) => [
...commonKey(mediaIdentifier),
taskId,
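If the reviewer's ROI question above is taken up, the new key could mirror the PREDICTIONS key by accepting an optional roiId. A minimal sketch of that possible follow-up (the roiId parameter here is hypothetical, not part of this diff):

    // Hypothetical follow-up: key explanations per ROI, as PREDICTIONS already does,
    // so task-chain explanation maps requested for different ROIs are cached separately.
    EXPLANATIONS: (
        datasetIdentifier: DatasetIdentifier,
        mediaIdentifier: MediaIdentifier | undefined,
        taskId?: string,
        roiId?: string
    ) => [...commonKey(mediaIdentifier), 'explanations', datasetIdentifier, mediaIdentifier, taskId, roiId],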
@@ -7,12 +7,16 @@ import { ProjectIdentifier } from '../../projects/core.interface';
import { DatasetIdentifier } from '../../projects/dataset.interface';
import { Annotation, TaskChainInput } from '../annotation.interface';
import { Explanation } from '../prediction.interface';
- import { InferenceServerStatusResult, PredictionCache, PredictionMode } from './prediction-service.interface';
+ import { PredictionCache, PredictionMode } from './prediction-service.interface';
import { VideoPaginationOptions } from './video-pagination-options.interface';

export type InferenceResult = ReadonlyArray<Annotation>;
export type ExplanationResult = Explanation[];

+ export interface InferenceServerStatusResult {
+     isInferenceServerReady: boolean;
+ }

export interface InferenceService {
getTestPredictions: (
projectIdentifier: ProjectIdentifier,
@@ -30,8 +30,13 @@ import {
PipelineServerStatusDTO,
} from '../../dtos/prediction.interface';
import { Rect } from '../../shapes.interface';
- import { ExplanationResult, InferenceResult, InferenceService } from '../inference-service.interface';
- import { InferenceServerStatusResult, PredictionCache, PredictionMode } from '../prediction-service.interface';
+ import {
+     ExplanationResult,
+     InferenceResult,
+     InferenceServerStatusResult,
+     InferenceService,
+ } from '../inference-service.interface';
+ import { PredictionCache, PredictionMode } from '../prediction-service.interface';
import {
buildPredictionParams,
getExplanations as convertExplanations,
@@ -2,15 +2,9 @@
// LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE

import { Annotation } from '../annotation.interface';
- import { Explanation } from '../prediction.interface';

export interface PredictionResult {
annotations: ReadonlyArray<Annotation>;
-     maps: Explanation[];
}

- export interface InferenceServerStatusResult {
-     isInferenceServerReady: boolean;
- }

export enum PredictionMode {
@@ -59,6 +59,7 @@ describe('LayersFactory', () => {
selectedMediaItem,
selectedMediaItemQuery: { isLoading: false },
predictionsQuery: { data: undefined },
+ explanationsQuery: { data: undefined },
} as SelectedMediaItemProps);

const response = await annotatorRender(
@@ -109,7 +109,7 @@ export const AnnotatorProvider = ({ children }: AnnotatorProviderProps): JSX.Ele
settings={userProjectSettings}
userAnnotationScene={userAnnotationScene}
initPredictions={initialPredictionAnnotations}
- explanations={selectedMediaItem?.predictions?.maps || EMPTY_EXPLANATION}
+ explanations={selectedMediaItem?.explanations || EMPTY_EXPLANATION}
Contributor: Let's try to:

Suggested change:
- explanations={selectedMediaItem?.explanations || EMPTY_EXPLANATION}
+ explanations={explanationsQuery.data?.explanations || EMPTY_EXPLANATION}
+ // or
+ explanationsQuery={explanationsQuery}

where we can define the query inside of this function, or possibly remove the argument and use this query directly inside of the prediction provider.

>
<SubmitAnnotationsProvider
settings={userProjectSettings}
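A minimal sketch of the reviewer's second option — dropping the prop and reading the query inside the prediction provider. Using useSelectedMediaItem as the context consumer is an assumption based on the surrounding code, not this PR's final API:

    // Hypothetical: the provider derives explanations from the query exposed by the
    // selected-media-item context instead of receiving them through a prop.
    const { explanationsQuery } = useSelectedMediaItem();
    const explanations = explanationsQuery.data ?? EMPTY_EXPLANATION;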
@@ -33,7 +33,8 @@ const geSelectedMediaItem = (predictions: Annotation[] = [], annotations: Annota
...getMockedImageMediaItem({}),
image: getMockedImage(),
annotations,
- predictions: { annotations: predictions, maps: [] },
+ predictions: { annotations: predictions },
Contributor (author): I think in a different PR we can also simplify this interface.

+ explanations: [],
});
const getInitialPredictionConfig = (isEnabled = true) => ({
[FEATURES_KEYS.INITIAL_PREDICTION]: {
@@ -326,7 +326,7 @@ export const PredictionProvider = ({
selectedInput,
Contributor: Regarding this file: I think we can refactor this provider so that we no longer need to keep track of the explanations inside of it. Instead, it should be possible to retrieve the explanations directly from a query result.

taskId: String(selectedTask?.id),
enabled: isPredictionsQueryEnabled,
- onSuccess: runWhenNotDrawing(({ annotations: newRawPredictions, maps: newMaps }: PredictionResult) => {
+ onSuccess: runWhenNotDrawing(({ annotations: newRawPredictions }: PredictionResult) => {
const selectedPredictions = selectAnnotations(newRawPredictions, userSceneSelectedInputs);

if (isEmpty(userAnnotations)) {
@@ -348,7 +348,6 @@
canUpdatePrediction && userAnnotationScene.updateAnnotation(newPrediction);
}

- setExplanations(newMaps);
setRawPredictions(selectedPredictions);

// Optionally update video timeline predictions
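One shape the refactor proposed in the comment above could take — a sketch assuming the explanationsQuery exposed by the selected-media-item provider in this PR; the useExplanations hook name is hypothetical:

    // Hypothetical: explanations come straight from the query result, so this
    // provider no longer needs its setExplanations/useState bookkeeping.
    const useExplanations = (): ExplanationResult => {
        const { explanationsQuery } = useSelectedMediaItem();

        return explanationsQuery.data ?? [];
    };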
@@ -6,7 +6,7 @@ import { useApplicationServices } from '@geti/core/src/services/application-serv
import { useQuery, UseQueryResult } from '@tanstack/react-query';
import { AxiosError } from 'axios';

- import { InferenceServerStatusResult } from '../../../../core/annotations/services/prediction-service.interface';
+ import { InferenceServerStatusResult } from '../../../../core/annotations/services/inference-service.interface';
import { ProjectIdentifier } from '../../../../core/projects/core.interface';
import { useTask } from '../task-provider/task-provider.component';

@@ -135,7 +135,7 @@ describe('usePredictionsRoiQuery', () => {
});

await waitFor(() => {
- expect(result.current.data).toEqual({ annotations: [], maps: [] });
+ expect(result.current.data).toEqual({ annotations: [] });
});
});

@@ -149,7 +149,6 @@
await waitFor(() => {
expect(result.current.data).toEqual({
annotations: [],
- maps: [],
});
});

@@ -181,11 +180,9 @@
expect.anything()
);
});

- expect(mockedInferenceService.getExplanations).not.toHaveBeenCalled();
});

- describe('PredictionMode.ONLINE is sent as PredictionCache.NEVER, calls "getPredictions" and "getExplanations"', () => {
+ describe('PredictionMode.ONLINE is sent as PredictionCache.NEVER, calls "getPredictions"', () => {
it('successful responses', async () => {
jest.mocked(useAnnotatorMode).mockImplementation(() => ({
isActiveLearningMode: false,
@@ -210,15 +207,6 @@
expect.anything()
);
});

- expect(mockedInferenceService.getExplanations).toHaveBeenLastCalledWith(
-     datasetIdentifier,
-     selectedMediaItem,
-     taskId,
-     selectedInput,
-     // AbortController
-     expect.anything()
- );
});

it('rejected requests are handle as empty', async () => {
@@ -227,7 +215,6 @@
currentMode: ANNOTATOR_MODE.PREDICTION,
}));
jest.mocked(mockedInferenceService.getPredictions).mockRejectedValue('test error');
- jest.mocked(mockedInferenceService.getExplanations).mockRejectedValue('test error');

const { result } = renderPredictionsRoiQuery({
selectedInput,
@@ -236,7 +223,7 @@
});

await waitFor(() => {
- expect(result.current.data).toEqual({ annotations: [], maps: [] });
+ expect(result.current.data).toEqual({ annotations: [] });
});
});
});
@@ -9,11 +9,7 @@ import { useQuery } from '@tanstack/react-query';
import { AxiosError } from 'axios';

import { TaskChainInput } from '../../../../core/annotations/annotation.interface';
- import {
-     PredictionCache,
-     PredictionMode,
-     PredictionResult,
- } from '../../../../core/annotations/services/prediction-service.interface';
+ import { PredictionMode, PredictionResult } from '../../../../core/annotations/services/prediction-service.interface';
import { getPredictionCache } from '../../../../core/annotations/services/utils';
import { usePrevious } from '../../../../hooks/use-previous/use-previous.hook';
import { useProject } from '../../../project-details/providers/project-provider/project-provider.component';
@@ -45,21 +41,17 @@ export const usePredictionsRoiQuery = ({
const { inferenceService } = useApplicationServices();

const isValidRoi = roiId !== undefined && prevRoi !== roiId;

const isQueryEnabled = enabled && isValidRoi;

const predictionMode = isActiveLearningMode ? PredictionMode.AUTO : PredictionMode.ONLINE;
const predictionCache = getPredictionCache(predictionMode);
- const isPredictionCacheNever = predictionCache === PredictionCache.NEVER;

const handleSuccessRef = useRef(onSuccess);

useLayoutEffect(() => {
handleSuccessRef.current = onSuccess;
}, [onSuccess]);

- // TODO: extract explanation, combine with other prediction query

const query = useQuery<PredictionResult, AxiosError>({
queryKey: QUERY_KEYS.SELECTED_MEDIA_ITEM.PREDICTIONS(
selectedMediaItem?.identifier,
@@ -69,39 +61,25 @@
roiId
),
queryFn: async ({ signal }) => {
-         if (isQueryEnabled === false) {
-             return { maps: [], annotations: [] };
+         if (!isQueryEnabled || !selectedMediaItem) {
+             return { annotations: [] };
          }

-         if (selectedMediaItem === undefined) {
-             return { maps: [], annotations: [] };
-         }
-
-         const explainPromise = isPredictionCacheNever
-             ? inferenceService.getExplanations(datasetIdentifier, selectedMediaItem, taskId, selectedInput, signal)
-             : Promise.resolve([]);
-
-         return Promise.allSettled([
-             inferenceService.getPredictions(
-                 datasetIdentifier,
-                 project.labels,
-                 selectedMediaItem,
-                 predictionCache,
-                 taskId,
-                 selectedInput,
-                 signal
-             ),
-             explainPromise,
-         ]).then(([annotationsResponse, mapsResponse]) => {
-             const maps = mapsResponse.status === 'fulfilled' ? mapsResponse.value : [];
-             const annotations = annotationsResponse.status === 'fulfilled' ? annotationsResponse.value : [];
+         const annotations = await inferenceService.getPredictions(
+             datasetIdentifier,
+             project.labels,
+             selectedMediaItem,
+             predictionCache,
+             taskId,
+             selectedInput,
+             signal
+         );

-             handleSuccessRef.current !== undefined && handleSuccessRef.current({ annotations, maps });
+         handleSuccessRef.current?.({ annotations });

-             return { annotations, maps };
-         });
+         return { annotations };
      },
-     initialData: { maps: [], annotations: [] },
+     initialData: { annotations: [] },
enabled: isQueryEnabled,
});

@@ -7,6 +7,7 @@ import QUERY_KEYS from '@geti/core/src/requests/query-keys';
import { useQuery, UseQueryResult } from '@tanstack/react-query';
import { noop } from 'lodash-es';

+ import { ExplanationResult } from '../../../../core/annotations/services/inference-service.interface';
import { PredictionResult } from '../../../../core/annotations/services/prediction-service.interface';
import { SelectedMediaItemContext, SelectedMediaItemProps } from './selected-media-item-provider.component';
import { SelectedMediaItem } from './selected-media-item.interface';
@@ -40,6 +41,7 @@ export const DefaultSelectedMediaItemProvider = ({
selectedMediaItemQuery,
setSelectedMediaItem: noop,
predictionsQuery: { isLoading: false } as UseQueryResult<PredictionResult>,
+ explanationsQuery: { isLoading: false } as UseQueryResult<ExplanationResult>,
Contributor: Is it possible to completely remove this and instead use the query directly?

};

return <SelectedMediaItemContext.Provider value={value}>{children}</SelectedMediaItemContext.Provider>;
@@ -12,6 +12,7 @@ import { AxiosError } from 'axios';
import { isEmpty, isEqual } from 'lodash-es';

import { Annotation } from '../../../../core/annotations/annotation.interface';
+ import { ExplanationResult } from '../../../../core/annotations/services/inference-service.interface';
import { PredictionMode, PredictionResult } from '../../../../core/annotations/services/prediction-service.interface';
import { InferenceModel } from '../../../../core/annotations/services/visual-prompt-service';
import { MediaItem } from '../../../../core/media/media.interface';
@@ -31,6 +32,7 @@ import { hasEmptyLabels } from '../../utils';
import { useTask } from '../task-provider/task-provider.component';
import { SelectedMediaItem } from './selected-media-item.interface';
import { useAnnotationsQuery } from './use-annotations-query.hook';
+ import { useExplanationsQuery } from './use-explanation-query.hook';
import { useLoadImageQuery } from './use-load-image-query.hook';
import { usePredictionsQuery } from './use-predictions-query.hook';
import { useSelectedInferenceModel } from './use-selected-inference-model';
@@ -43,6 +45,7 @@

export interface SelectedMediaItemProps {
predictionsQuery: UseQueryResult<PredictionResult>;
+ explanationsQuery: UseQueryResult<ExplanationResult>;
selectedMediaItem: SelectedMediaItem | undefined;
selectedMediaItemQuery: UseQueryResult<SelectedMediaItem>;
setSelectedMediaItem: (media: MediaItem | undefined) => void;
@@ -87,7 +90,7 @@ const useIsSuggestPredictionEnabled = (projectIdentifier: ProjectIdentifier): bo
};

/**
- * Load either online or auto predicitons based on the annotator mode
+ * Load either online or auto predictions based on the annotator mode
* When the user switches between both modes we don't want to refetch predictions,
* so in both cases we will keep the queries mounted
*/
@@ -169,10 +172,20 @@ export const SelectedMediaItemProvider = ({ children }: SelectedMediaItemProvide
});

const predictionsQuery = usePredictionsQueryBasedOnAnnotatorMode(mediaItem);
+ const explanationsQuery = useExplanationsQuery({
Contributor: Similar question/comment as above: I'd prefer if we can fully remove it from this provider. The logic in here is already a bit too complex, trying to build a dependent query on top of three other queries; adding one more could result in some more subtle bugs.

+     datasetIdentifier,
+     mediaItem,
+     taskId: selectedTask?.id,
+ });

const selectedMediaItemQueryKey = [
...QUERY_KEYS.SELECTED_MEDIA_ITEM.SELECTED(pendingMediaItem?.identifier, selectedTask?.id),
-     [imageQuery.fetchStatus, annotationsQuery.fetchStatus, predictionsQuery.fetchStatus],
+     [
+         imageQuery.fetchStatus,
+         annotationsQuery.fetchStatus,
+         predictionsQuery.fetchStatus,
+         explanationsQuery.fetchStatus,
+     ],
];

const isSelectedMediaItemQueryEnabled = mediaItem !== undefined;
@@ -184,31 +197,32 @@
throw new Error("Can't fetch undefined media item");
}

-         const [image, annotations, predictions] = await new Promise<[ImageData, Annotation[], PredictionResult]>(
-             (resolve, reject) => {
-                 if (imageQuery.isError) {
-                     reject({
-                         message: 'Failed loading media item. Please try refreshing or selecting a different item.',
-                     });
-                 }
+         const [image, annotations, predictions, explanations] = await new Promise<
+             [ImageData, Annotation[], PredictionResult, ExplanationResult]
+         >((resolve, reject) => {
+             if (imageQuery.isError) {
+                 reject({
+                     message: 'Failed loading media item. Please try refreshing or selecting a different item.',
+                 });
+             }

-                 if (imageQuery.data && annotationsQuery.data) {
-                     if (!predictionsQuery.data && predictionsQuery.isFetching) {
-                         // If we do not yet have predictions and the user has not yet made any annotations
-                         // for the selected task, then we will wait for predictions
-                         if (isNotAnnotatedForTask(annotationsQuery.data, selectedTask)) {
-                             return;
-                         }
+             if (imageQuery.data && annotationsQuery.data) {
+                 if (!predictionsQuery.data && predictionsQuery.isFetching) {
+                     // If we do not yet have predictions and the user has not yet made any annotations
+                     // for the selected task, then we will wait for predictions
+                     if (isNotAnnotatedForTask(annotationsQuery.data, selectedTask)) {
+                         return;
+                     }
                 }

-                 const predictionsData = predictionsQuery.data ?? { maps: [], annotations: [] };
+                 const predictionsData = predictionsQuery.data ?? { annotations: [] };
+                 const explanationsData = explanationsQuery.data ?? [];

-                 resolve([imageQuery.data, annotationsQuery.data, predictionsData]);
-             }
+                 resolve([imageQuery.data, annotationsQuery.data, predictionsData, explanationsData]);
+             }
-         );
+         });

-         const newlySelectedMediaItem = { ...mediaItem, image, annotations, predictions };
+         const newlySelectedMediaItem = { ...mediaItem, image, annotations, predictions, explanations };

if (isSingleDomainProject(isClassificationDomain)) {
return {
@@ -266,6 +280,7 @@

const value: SelectedMediaItemProps = {
predictionsQuery,
+ explanationsQuery,
selectedMediaItem,
selectedMediaItemQuery,
setSelectedMediaItem: setPendingMediaItem,
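If the query is lifted out of this provider as the reviewers suggest, consumers could call the useExplanationsQuery hook added in this PR directly, keyed by QUERY_KEYS.SELECTED_MEDIA_ITEM.EXPLANATIONS so results stay cached per media item. A sketch under that assumption (useDatasetIdentifier and the exact option names are assumptions):

    // Hypothetical consumer: builds the explanations query from existing context,
    // so SelectedMediaItemProvider no longer has to coordinate a fourth query.
    const useSelectedItemExplanations = (): ExplanationResult => {
        const datasetIdentifier = useDatasetIdentifier();
        const { selectedTask } = useTask();
        const { selectedMediaItem } = useSelectedMediaItem();

        const explanationsQuery = useExplanationsQuery({
            datasetIdentifier,
            mediaItem: selectedMediaItem,
            taskId: selectedTask?.id,
        });

        return explanationsQuery.data ?? [];
    };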