Skip to content

Commit 7cd61aa

Browse files
chore: simplified toolCallingAccuracy calculation
1 parent 312b2a5 commit 7cd61aa

File tree

8 files changed

+109
-136
lines changed

8 files changed

+109
-136
lines changed

package-lock.json

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@
5656
"jest-environment-node": "^29.7.0",
5757
"jest-extended": "^6.0.0",
5858
"json-schema": "^0.4.0",
59+
"microdiff": "^1.5.0",
5960
"mongodb-runner": "^5.8.2",
6061
"ollama-ai-provider": "^1.2.0",
6162
"openapi-types": "^12.1.3",
tests/accuracy/sdk/accuracy-scorers.ts

Lines changed: 41 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -1,133 +1,60 @@
1-
export type ToolCall = {
2-
toolCallId: string;
3-
toolName: string;
4-
parameters: unknown;
5-
};
6-
export type ExpectedToolCall = Omit<ToolCall, "toolCallId">;
1+
import diff from "microdiff";
2+
import { ExpectedToolCall, ActualToolCall } from "./accuracy-snapshot-storage/snapshot-storage.js";
73

8-
export function toolCallingAccuracyScorer(expectedToolCalls: ExpectedToolCall[], actualToolCalls: ToolCall[]): number {
9-
if (actualToolCalls.length < expectedToolCalls.length) {
10-
return 0;
11-
}
12-
13-
const possibleScore = actualToolCalls.length > expectedToolCalls.length ? 0.75 : 1;
14-
const checkedToolCallIds = new Set<string>();
15-
for (const expectedToolCall of expectedToolCalls) {
16-
const matchingActualToolCall = actualToolCalls.find(
17-
(actualToolCall) =>
18-
actualToolCall.toolName === expectedToolCall.toolName &&
19-
!checkedToolCallIds.has(actualToolCall.toolCallId)
20-
);
21-
22-
if (!matchingActualToolCall) {
23-
return 0;
24-
}
25-
26-
checkedToolCallIds.add(matchingActualToolCall.toolCallId);
27-
}
28-
29-
return possibleScore;
30-
}
31-
32-
export function parameterMatchingAccuracyScorer(
4+
export function calculateToolCallingAccuracy(
335
expectedToolCalls: ExpectedToolCall[],
34-
actualToolCalls: ToolCall[]
6+
actualToolCalls: ActualToolCall[]
357
): number {
368
if (expectedToolCalls.length === 0) {
37-
return 1;
9+
return actualToolCalls.length === 0 ? 1 : 0.75;
3810
}
3911

40-
const usedActualIndexes = new Set<number>();
41-
const scores: number[] = [];
12+
const maxAccuracy = actualToolCalls.length > expectedToolCalls.length ? 0.75 : 1;
13+
14+
const individualAccuracies: number[] = [];
15+
const checkedActualToolCallIndexes = new Set<number>();
4216

4317
for (const expectedCall of expectedToolCalls) {
44-
// Find all unmatched actual tool calls with the same tool name
4518
const candidates = actualToolCalls
4619
.map((call, index) => ({ call, index }))
47-
.filter(({ call, index }) => !usedActualIndexes.has(index) && call.toolName === expectedCall.toolName);
48-
49-
if (candidates.length === 0) {
50-
scores.push(0);
51-
continue;
52-
}
53-
54-
// Pick the candidate with the best parameter match
55-
let bestScore = -1;
56-
let bestIndex = -1;
57-
for (const { call, index } of candidates) {
58-
const score = compareParams(expectedCall.parameters, call.parameters);
59-
if (score > bestScore) {
60-
bestScore = score;
61-
bestIndex = index;
62-
}
63-
}
64-
65-
usedActualIndexes.add(bestIndex);
66-
scores.push(bestScore);
67-
}
68-
69-
const totalScore = scores.reduce((sum, score) => sum + score, 0);
70-
return totalScore / scores.length;
20+
.filter(
21+
({ call, index }) => !checkedActualToolCallIndexes.has(index) && call.toolName === expectedCall.toolName
22+
)
23+
.map(({ call, index }) => ({
24+
call,
25+
index,
26+
score: compareParams(expectedCall.parameters, call.parameters),
27+
}))
28+
.filter(({ score }) => score >= 0.75)
29+
.sort((a, b) => b.score - a.score);
30+
31+
const bestMatch = candidates[0];
32+
if (!bestMatch) {
33+
individualAccuracies.push(0);
34+
} else {
35+
checkedActualToolCallIndexes.add(bestMatch.index);
36+
const individualAccuracy = Math.min(bestMatch.score, maxAccuracy);
37+
individualAccuracies.push(individualAccuracy);
38+
}
39+
}
40+
41+
return Math.min(...individualAccuracies);
7142
}
7243

73-
/**
74-
* Recursively compares expected and actual parameters and returns a score.
75-
* - 1: Perfect match.
76-
* - 0.75: All expected parameters are present and match, but there are extra actual parameters.
77-
* - 0: Missing parameters or mismatched values.
78-
*/
79-
function compareParams(expected: unknown, actual: unknown): number {
80-
if (expected === null || expected === undefined) {
81-
return actual === null || actual === undefined ? 1 : 0;
82-
}
83-
if (actual === null || actual === undefined) {
84-
return 0;
85-
}
44+
function compareParams(expected: Record<string, unknown>, actual: Record<string, unknown>): number {
45+
const differences = diff(expected, actual);
8646

87-
if (Array.isArray(expected)) {
88-
if (!Array.isArray(actual) || actual.length < expected.length) {
89-
return 0;
90-
}
91-
let minScore = 1;
92-
for (let i = 0; i < expected.length; i++) {
93-
minScore = Math.min(minScore, compareParams(expected[i], actual[i]));
94-
}
95-
if (minScore === 0) {
96-
return 0;
97-
}
98-
if (actual.length > expected.length) {
99-
minScore = Math.min(minScore, 0.75);
100-
}
101-
return minScore;
47+
if (differences.length === 0) {
48+
return 1;
10249
}
10350

104-
if (typeof expected === "object") {
105-
if (typeof actual !== "object" || Array.isArray(actual)) {
106-
return 0;
107-
}
108-
const expectedKeys = Object.keys(expected as Record<string, unknown>);
109-
const actualKeys = Object.keys(actual as Record<string, unknown>);
110-
111-
let minScore = 1;
112-
for (const key of expectedKeys) {
113-
if (!Object.prototype.hasOwnProperty.call(actual, key)) {
114-
return 0;
115-
}
116-
minScore = Math.min(
117-
minScore,
118-
compareParams((expected as Record<string, unknown>)[key], (actual as Record<string, unknown>)[key])
119-
);
120-
}
51+
const hasOnlyAdditions = differences.every((d) => d.type === "CREATE");
52+
const hasRemovals = differences.some((d) => d.type === "REMOVE");
53+
const hasChanges = differences.some((d) => d.type === "CHANGE");
12154

122-
if (minScore === 0) {
123-
return 0;
124-
}
125-
126-
if (actualKeys.length > expectedKeys.length) {
127-
minScore = Math.min(minScore, 0.75);
128-
}
129-
return minScore;
55+
if (hasOnlyAdditions && !hasRemovals && !hasChanges) {
56+
return 0.75;
13057
}
13158

132-
return expected == actual ? 1 : 0;
59+
return 0;
13360
}

tests/accuracy/sdk/accuracy-snapshot-storage/mdb-snapshot-storage.ts

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,11 +28,13 @@ export class MongoDBSnapshotStorage implements AccuracySnapshotStorage {
2828
async createSnapshotEntry(
2929
snapshotEntry: Pick<
3030
AccuracySnapshotEntry,
31+
| "provider"
3132
| "requestedModel"
3233
| "test"
3334
| "prompt"
3435
| "toolCallingAccuracy"
35-
| "parameterAccuracy"
36+
| "expectedToolCalls"
37+
| "actualToolCalls"
3638
| "llmResponseTime"
3739
| "tokensUsage"
3840
| "respondingModel"

tests/accuracy/sdk/accuracy-snapshot-storage/snapshot-storage.ts

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,30 @@
11
import z from "zod";
22

3+
const ExpectedToolCallSchema = z.object({
4+
toolCallId: z.string(),
5+
toolName: z.string(),
6+
parameters: z.record(z.string(), z.unknown()),
7+
});
8+
9+
// Zod's `omit` mask marks each key to drop with `true`; a non-`true` value does not
// reliably remove the key from the runtime schema even though the inferred type drops it.
const ActualToolCallSchema = ExpectedToolCallSchema.omit({ toolCallId: true });
10+
11+
export type ExpectedToolCall = z.infer<typeof ExpectedToolCallSchema>;
12+
export type ActualToolCall = z.infer<typeof ActualToolCallSchema>;
13+
314
export const AccuracySnapshotEntrySchema = z.object({
415
// Git and meta information for snapshot entries
516
accuracyRunId: z.string(),
617
createdOn: z.number(),
718
commitSHA: z.string(),
819
// Accuracy info
20+
provider: z.string(),
921
requestedModel: z.string(),
1022
test: z.string(),
1123
prompt: z.string(),
1224
toolCallingAccuracy: z.number(),
13-
parameterAccuracy: z.number(),
25+
// debug info for further investigations
26+
expectedToolCalls: ExpectedToolCallSchema.array(),
27+
actualToolCalls: ActualToolCallSchema.array(),
1428
llmResponseTime: z.number(),
1529
tokensUsage: z
1630
.object({
@@ -30,11 +44,13 @@ export interface AccuracySnapshotStorage {
3044
createSnapshotEntry(
3145
snapshotEntry: Pick<
3246
AccuracySnapshotEntry,
47+
| "provider"
3348
| "requestedModel"
3449
| "test"
3550
| "prompt"
3651
| "toolCallingAccuracy"
37-
| "parameterAccuracy"
52+
| "expectedToolCalls"
53+
| "actualToolCalls"
3854
| "llmResponseTime"
3955
| "tokensUsage"
4056
| "respondingModel"

tests/accuracy/sdk/accuracy-testing-client.ts

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ import { experimental_createMCPClient as createMCPClient, tool as createVercelTo
55
import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
66
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
77

8-
import { ToolCall } from "./accuracy-scorers.js";
8+
import { ExpectedToolCall } from "./accuracy-snapshot-storage/snapshot-storage.js";
99

1010
const __dirname = fileURLToPath(import.meta.url);
1111
const distPath = path.join(__dirname, "..", "..", "..", "..", "dist");
@@ -16,7 +16,7 @@ export type MockedTools = Record<string, ToolResultGeneratorFn>;
1616

1717
export class AccuracyTestingClient {
1818
private mockedTools: MockedTools = {};
19-
private recordedToolCalls: ToolCall[] = [];
19+
private recordedToolCalls: ExpectedToolCall[] = [];
2020
private constructor(private readonly vercelMCPClient: Awaited<ReturnType<typeof createMCPClient>>) {}
2121

2222
async close() {
@@ -33,7 +33,7 @@ export class AccuracyTestingClient {
3333
this.recordedToolCalls.push({
3434
toolCallId: uuid(),
3535
toolName: toolName,
36-
parameters: args,
36+
parameters: args as Record<string, unknown>,
3737
});
3838
try {
3939
const toolResultGeneratorFn = this.mockedTools[toolName];

tests/accuracy/sdk/describe-accuracy-tests.ts

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
import { TestableModels } from "./models.js";
2-
import { ExpectedToolCall, parameterMatchingAccuracyScorer, toolCallingAccuracyScorer } from "./accuracy-scorers.js";
2+
import { calculateToolCallingAccuracy } from "./accuracy-scorers.js";
33
import { getVercelToolCallingAgent, VercelAgent } from "./agent.js";
44
import { prepareTestData, setupMongoDBIntegrationTest } from "../../integration/tools/mongodb/mongodbHelpers.js";
55
import { AccuracyTestingClient, MockedTools } from "./accuracy-testing-client.js";
66
import { getAccuracySnapshotStorage } from "./accuracy-snapshot-storage/get-snapshot-storage.js";
7-
import { AccuracySnapshotStorage } from "./accuracy-snapshot-storage/snapshot-storage.js";
7+
import { AccuracySnapshotStorage, ExpectedToolCall } from "./accuracy-snapshot-storage/snapshot-storage.js";
88

99
export interface AccuracyTestConfig {
1010
systemPrompt?: string;
@@ -33,7 +33,7 @@ export function describeAccuracyTests(
3333
const eachModel = describe.each(models);
3434
const eachSuite = describe.each(Object.keys(accuracyTestConfigs));
3535

36-
eachModel(`$modelName`, function (model) {
36+
eachModel(`$displayName`, function (model) {
3737
const mdbIntegration = setupMongoDBIntegrationTest();
3838
const { populateTestData, cleanupTestDatabases } = prepareTestData(mdbIntegration);
3939

@@ -72,20 +72,18 @@ export function describeAccuracyTests(
7272
const result = await agent.prompt(promptForModel, model, toolsForModel);
7373
const timeAfterPrompt = Date.now();
7474
const toolCalls = testMCPClient.getToolCalls();
75-
const toolCallingAccuracy = toolCallingAccuracyScorer(testConfig.expectedToolCalls, toolCalls);
76-
const parameterMatchingAccuracy = parameterMatchingAccuracyScorer(
77-
testConfig.expectedToolCalls,
78-
toolCalls
79-
);
75+
const toolCallingAccuracy = calculateToolCallingAccuracy(testConfig.expectedToolCalls, toolCalls);
8076

8177
const responseTime = timeAfterPrompt - timeBeforePrompt;
8278
await accuracySnapshotStorage.createSnapshotEntry({
79+
provider: model.provider,
8380
requestedModel: model.modelName,
8481
test: suiteName,
8582
prompt: testConfig.prompt,
8683
llmResponseTime: responseTime,
87-
toolCallingAccuracy,
88-
parameterAccuracy: parameterMatchingAccuracy,
84+
toolCallingAccuracy: toolCallingAccuracy,
85+
actualToolCalls: toolCalls,
86+
expectedToolCalls: testConfig.expectedToolCalls,
8987
...result,
9088
});
9189
});

0 commit comments

Comments
 (0)