Skip to content

Commit d882424

Browse files
authored
[Question Answering] Cache support for Model Overview Metrics (#2166)
1 parent 4506b0d commit d882424

File tree

11 files changed

+182
-94
lines changed

11 files changed

+182
-94
lines changed

apps/widget/src/app/ModelAssessment.tsx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,15 @@ export class ModelAssessment extends React.Component<IModelAssessmentProps> {
5151
};
5252
callBack.requestQuestionAnsweringMetrics = async (
5353
selectionIndexes: number[][],
54+
questionAnsweringCache: Map<
55+
string,
56+
[number, number, number, number, number, number]
57+
>,
5458
abortSignal: AbortSignal
5559
): Promise<any[]> => {
5660
return callFlaskService(
5761
this.props.config,
58-
[selectionIndexes],
62+
[selectionIndexes, questionAnsweringCache],
5963
"/get_question_answering_metrics",
6064
abortSignal
6165
);

libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ export interface IModelAssessmentContext {
153153
requestQuestionAnsweringMetrics?:
154154
| ((
155155
selectionIndexes: number[][],
156+
questionAnsweringCache: Map<
157+
string,
158+
[number, number, number, number, number, number]
159+
>,
156160
abortSignal: AbortSignal
157161
) => Promise<any[]>)
158162
| undefined;

libs/core-ui/src/lib/util/ImageStatisticsUtils.ts

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT License.
33

4+
import { localization } from "@responsible-ai/localization";
5+
6+
import {
7+
ILabeledStatistic,
8+
TotalCohortSamples
9+
} from "../Interfaces/IStatistic";
10+
411
export enum ImageClassificationMetrics {
512
Accuracy = "accuracy",
613
MacroF1 = "f1",
@@ -56,3 +63,66 @@ export const generateMicroMacroMetrics = (
5663
microScore
5764
};
5865
};
66+
67+
export const generateImageStats: (
68+
trueYs: number[],
69+
predYs: number[]
70+
) => ILabeledStatistic[] = (
71+
trueYs: number[],
72+
predYs: number[]
73+
): ILabeledStatistic[] => {
74+
const correctCount = predYs.filter(
75+
(pred, index) => pred === trueYs[index]
76+
).length;
77+
const accuracy = correctCount / predYs.length;
78+
const precision = generateMicroMacroMetrics(predYs, trueYs);
79+
const microP = precision.microScore;
80+
const macroP = precision.macroScore;
81+
const recall = generateMicroMacroMetrics(trueYs, predYs);
82+
const microR = recall.microScore;
83+
const macroR = recall.macroScore;
84+
const microF1 = 2 * ((microP * microR) / (microP + microR)) || 0;
85+
const macroF1 = 2 * ((macroP * macroR) / (macroP + macroR)) || 0;
86+
return [
87+
{
88+
key: TotalCohortSamples,
89+
label: localization.Interpret.Statistics.samples,
90+
stat: predYs.length
91+
},
92+
{
93+
key: ImageClassificationMetrics.Accuracy,
94+
label: localization.Interpret.Statistics.accuracy,
95+
stat: accuracy
96+
},
97+
{
98+
key: ImageClassificationMetrics.MicroPrecision,
99+
label: localization.Interpret.Statistics.precision,
100+
stat: microP
101+
},
102+
{
103+
key: ImageClassificationMetrics.MicroRecall,
104+
label: localization.Interpret.Statistics.recall,
105+
stat: microR
106+
},
107+
{
108+
key: ImageClassificationMetrics.MicroF1,
109+
label: localization.Interpret.Statistics.f1Score,
110+
stat: microF1
111+
},
112+
{
113+
key: ImageClassificationMetrics.MacroPrecision,
114+
label: localization.Interpret.Statistics.precision,
115+
stat: macroP
116+
},
117+
{
118+
key: ImageClassificationMetrics.MacroRecall,
119+
label: localization.Interpret.Statistics.recall,
120+
stat: macroR
121+
},
122+
{
123+
key: ImageClassificationMetrics.MacroF1,
124+
label: localization.Interpret.Statistics.f1Score,
125+
stat: macroF1
126+
}
127+
];
128+
};

libs/core-ui/src/lib/util/QuestionAnsweringStatisticsUtils.ts

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,33 @@ export enum QuestionAnsweringMetrics {
1818
}
1919

2020
export const generateQuestionAnsweringStats: (
21-
selectionIndexes: number[][]
21+
selectionIndexes: number[][],
22+
questionAnsweringCache: Map<
23+
string,
24+
[number, number, number, number, number, number]
25+
>
2226
) => ILabeledStatistic[][] = (
23-
selectionIndexes: number[][]
27+
selectionIndexes: number[][],
28+
questionAnsweringCache: Map<
29+
string,
30+
[number, number, number, number, number, number]
31+
>
2432
): ILabeledStatistic[][] => {
2533
return selectionIndexes.map((selectionArray) => {
2634
const count = selectionArray.length;
2735

36+
const value = questionAnsweringCache.get(selectionArray.toString());
37+
const stat = value
38+
? value
39+
: [
40+
Number.NaN,
41+
Number.NaN,
42+
Number.NaN,
43+
Number.NaN,
44+
Number.NaN,
45+
Number.NaN
46+
];
47+
2848
return [
2949
{
3050
key: TotalCohortSamples,
@@ -34,32 +54,32 @@ export const generateQuestionAnsweringStats: (
3454
{
3555
key: QuestionAnsweringMetrics.ExactMatchRatio,
3656
label: localization.Interpret.Statistics.exactMatchRatio,
37-
stat: Number.NaN
57+
stat: stat[0]
3858
},
3959
{
4060
key: QuestionAnsweringMetrics.F1Score,
4161
label: localization.Interpret.Statistics.f1Score,
42-
stat: Number.NaN
62+
stat: stat[1]
4363
},
4464
{
4565
key: QuestionAnsweringMetrics.MeteorScore,
4666
label: localization.Interpret.Statistics.meteorScore,
47-
stat: Number.NaN
67+
stat: stat[2]
4868
},
4969
{
5070
key: QuestionAnsweringMetrics.BleuScore,
5171
label: localization.Interpret.Statistics.bleuScore,
52-
stat: Number.NaN
72+
stat: stat[3]
5373
},
5474
{
5575
key: QuestionAnsweringMetrics.BertScore,
5676
label: localization.Interpret.Statistics.bertScore,
57-
stat: Number.NaN
77+
stat: stat[4]
5878
},
5979
{
6080
key: QuestionAnsweringMetrics.RougeScore,
6181
label: localization.Interpret.Statistics.rougeScore,
62-
stat: Number.NaN
82+
stat: stat[5]
6383
}
6484
];
6585
});

libs/core-ui/src/lib/util/StatisticsUtils.ts

Lines changed: 15 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@ import {
1010
} from "../Interfaces/IStatistic";
1111
import { IsBinary } from "../util/ExplanationUtils";
1212

13-
import {
14-
generateMicroMacroMetrics,
15-
ImageClassificationMetrics
16-
} from "./ImageStatisticsUtils";
13+
import { generateImageStats } from "./ImageStatisticsUtils";
1714
import { JointDataset } from "./JointDataset";
1815
import {
1916
ClassificationEnum,
@@ -28,6 +25,11 @@ import {
2825
RegressionMetrics
2926
} from "./StatisticsUtilsEnums";
3027

28+
type QuestionAnsweringCacheType = Map<
29+
string,
30+
[number, number, number, number, number, number]
31+
>;
32+
3133
const generateBinaryStats: (outcomes: number[]) => ILabeledStatistic[] = (
3234
outcomes: number[]
3335
): ILabeledStatistic[] => {
@@ -166,91 +168,32 @@ const generateMulticlassStats: (outcomes: number[]) => ILabeledStatistic[] = (
166168
];
167169
};
168170

169-
const generateImageStats: (
170-
trueYs: number[],
171-
predYs: number[]
172-
) => ILabeledStatistic[] = (
173-
trueYs: number[],
174-
predYs: number[]
175-
): ILabeledStatistic[] => {
176-
const correctCount = predYs.filter(
177-
(pred, index) => pred === trueYs[index]
178-
).length;
179-
const accuracy = correctCount / predYs.length;
180-
const precision = generateMicroMacroMetrics(predYs, trueYs);
181-
const microP = precision.microScore;
182-
const macroP = precision.macroScore;
183-
const recall = generateMicroMacroMetrics(trueYs, predYs);
184-
const microR = recall.microScore;
185-
const macroR = recall.macroScore;
186-
const microF1 = 2 * ((microP * microR) / (microP + microR)) || 0;
187-
const macroF1 = 2 * ((macroP * macroR) / (macroP + macroR)) || 0;
188-
189-
return [
190-
{
191-
key: TotalCohortSamples,
192-
label: localization.Interpret.Statistics.samples,
193-
stat: predYs.length
194-
},
195-
{
196-
key: ImageClassificationMetrics.Accuracy,
197-
label: localization.Interpret.Statistics.accuracy,
198-
stat: accuracy
199-
},
200-
{
201-
key: ImageClassificationMetrics.MicroPrecision,
202-
label: localization.Interpret.Statistics.precision,
203-
stat: microP
204-
},
205-
{
206-
key: ImageClassificationMetrics.MicroRecall,
207-
label: localization.Interpret.Statistics.recall,
208-
stat: microR
209-
},
210-
{
211-
key: ImageClassificationMetrics.MicroF1,
212-
label: localization.Interpret.Statistics.f1Score,
213-
stat: microF1
214-
},
215-
{
216-
key: ImageClassificationMetrics.MacroPrecision,
217-
label: localization.Interpret.Statistics.precision,
218-
stat: macroP
219-
},
220-
{
221-
key: ImageClassificationMetrics.MacroRecall,
222-
label: localization.Interpret.Statistics.recall,
223-
stat: macroR
224-
},
225-
{
226-
key: ImageClassificationMetrics.MacroF1,
227-
label: localization.Interpret.Statistics.f1Score,
228-
stat: macroF1
229-
}
230-
];
231-
};
232-
233171
export const generateMetrics: (
234172
jointDataset: JointDataset,
235173
selectionIndexes: number[][],
236174
modelType: ModelTypes,
237175
objectDetectionCache?: Map<string, [number, number, number]>,
238-
objectDetectionInputs?: [string, string, number]
176+
objectDetectionInputs?: [string, string, number],
177+
questionAnsweringCache?: QuestionAnsweringCacheType
239178
) => ILabeledStatistic[][] = (
240179
jointDataset: JointDataset,
241180
selectionIndexes: number[][],
242181
modelType: ModelTypes,
243182
objectDetectionCache?: Map<string, [number, number, number]>,
244-
objectDetectionInputs?: [string, string, number]
183+
objectDetectionInputs?: [string, string, number],
184+
questionAnsweringCache?: QuestionAnsweringCacheType
245185
): ILabeledStatistic[][] => {
246186
if (
247187
modelType === ModelTypes.ImageMultilabel ||
248188
modelType === ModelTypes.TextMultilabel
249189
) {
250190
return generateMultilabelStats(jointDataset, selectionIndexes);
251191
}
252-
if (modelType === ModelTypes.QuestionAnswering) {
253-
return generateQuestionAnsweringStats(selectionIndexes);
192+
if (modelType === ModelTypes.QuestionAnswering && questionAnsweringCache) {
193+
return generateQuestionAnsweringStats(
194+
selectionIndexes,
195+
questionAnsweringCache
196+
);
254197
}
255198
const trueYs = jointDataset.unwrap(JointDataset.TrueYLabel);
256199
const predYs = jointDataset.unwrap(JointDataset.PredictedYLabel);

libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ interface IModelOverviewProps {
6262
objectDetectionCache: Map<string, [number, number, number]>
6363
) => Promise<any[]>;
6464
requestQuestionAnsweringMetrics?: (
65-
selectionIndexes: number[][]
65+
selectionIndexes: number[][],
66+
questionAnsweringCache: Map<
67+
string,
68+
[number, number, number, number, number, number]
69+
>
6670
) => Promise<any[]>;
6771
}
6872

@@ -95,6 +99,10 @@ export class ModelOverview extends React.Component<
9599
IModelOverviewState
96100
> {
97101
public static contextType = ModelAssessmentContext;
102+
public questionAnsweringCache: Map<
103+
string,
104+
[number, number, number, number, number, number]
105+
> = new Map();
98106
public objectDetectionCache: Map<string, [number, number, number]> =
99107
new Map();
100108
public context: React.ContextType<typeof ModelAssessmentContext> =
@@ -610,7 +618,8 @@ export class ModelOverview extends React.Component<
610618
this.state.aggregateMethod,
611619
this.state.className,
612620
this.state.iouThreshold
613-
]
621+
],
622+
this.questionAnsweringCache
614623
);
615624

616625
this.setState({
@@ -715,6 +724,7 @@ export class ModelOverview extends React.Component<
715724
this.context
716725
.requestQuestionAnsweringMetrics(
717726
selectionIndexes,
727+
this.questionAnsweringCache,
718728
new AbortController().signal
719729
)
720730
.then((result) => {
@@ -734,6 +744,24 @@ export class ModelOverview extends React.Component<
734744
] of result.entries()) {
735745
const count = selectionIndexes[cohortIndex].length;
736746

747+
if (
748+
!this.questionAnsweringCache.has(
749+
selectionIndexes[cohortIndex].toString()
750+
)
751+
) {
752+
this.questionAnsweringCache.set(
753+
selectionIndexes[cohortIndex].toString(),
754+
[
755+
exactMatchRatio,
756+
f1Score,
757+
meteorScore,
758+
bleuScore,
759+
bertScore,
760+
rougeScore
761+
]
762+
);
763+
}
764+
737765
const updatedCohortMetricStats = [
738766
{
739767
key: TotalCohortSamples,
@@ -813,7 +841,8 @@ export class ModelOverview extends React.Component<
813841
this.state.aggregateMethod,
814842
this.state.className,
815843
this.state.iouThreshold
816-
]
844+
],
845+
this.questionAnsweringCache
817846
);
818847

819848
this.setState({

libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ export interface ITabsViewProps {
5959
) => Promise<any[]>;
6060
requestQuestionAnsweringMetrics?: (
6161
selectionIndexes: number[][],
62+
questionAnsweringCache: Map<
63+
string,
64+
[number, number, number, number, number, number]
65+
>,
6266
abortSignal: AbortSignal
6367
) => Promise<any[]>;
6468
requestDebugML?: (

0 commit comments

Comments
 (0)