Skip to content

Commit 58c853a

Browse files
chore: integrate capturing accuracy snapshots
1 parent f589fb3 commit 58c853a

File tree

5 files changed

+95
-12
lines changed

5 files changed

+95
-12
lines changed

package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
"check:types": "tsc --noEmit --project tsconfig.json",
3030
"reformat": "prettier --write .",
3131
"generate": "./scripts/generate.sh",
32-
"test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage"
32+
"test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage --testPathIgnorePatterns=/tests/accuracy/",
33+
"test:accuracy": "node --experimental-vm-modules node_modules/jest/bin/jest.js --testPathPattern tests/accuracy"
3334
},
3435
"license": "Apache-2.0",
3536
"devDependencies": {
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import fs from "fs/promises";
2+
import path from "path";
3+
import { z } from "zod";
4+
5+
/**
 * Absolute path of the accuracy snapshot file: `accuracy-snapshot.json`
 * resolved against the working directory of the running process.
 */
export const SNAPSHOT_FILE_PATH = path.resolve(process.cwd(), "accuracy-snapshot.json");
6+
7+
/**
 * Zod schema for a single accuracy snapshot record.
 *
 * As populated by `describeAccuracyTests`:
 * - `datetime` / `commit` come from the `ACCURACY_DATETIME` and
 *   `ACCURACY_COMMIT` environment variables;
 * - `model`, `suite` and `test` are the model name, suite name and test prompt;
 * - `toolCallingAccuracy` / `parameterAccuracy` are the two scorer results.
 */
export const AccuracySnapshotEntrySchema = z.object({
8+
datetime: z.string(),
9+
commit: z.string(),
10+
model: z.string(),
11+
suite: z.string(),
12+
test: z.string(),
13+
toolCallingAccuracy: z.number(),
14+
parameterAccuracy: z.number(),
15+
});
16+
17+
/** Static type of one snapshot record, inferred from `AccuracySnapshotEntrySchema`. */
export type AccuracySnapshotEntry = z.infer<typeof AccuracySnapshotEntrySchema>;
18+
19+
/**
 * Loads every accuracy snapshot entry stored at `SNAPSHOT_FILE_PATH`.
 *
 * @returns The parsed, schema-validated entries, or an empty array when the
 *   snapshot file does not exist yet.
 * @throws Any read error other than `ENOENT`, the `JSON.parse` error for
 *   malformed content, or the zod validation error when the content does not
 *   match `AccuracySnapshotEntrySchema`.
 */
export async function readSnapshot(): Promise<AccuracySnapshotEntry[]> {
    try {
        const raw = await fs.readFile(SNAPSHOT_FILE_PATH, "utf8");
        return AccuracySnapshotEntrySchema.array().parse(JSON.parse(raw));
    } catch (e: unknown) {
        // Narrow before reading `.code`: a thrown `null`/`undefined` would
        // otherwise raise a TypeError here and hide the original failure.
        const isMissingFile =
            typeof e === "object" && e !== null && (e as { code?: unknown }).code === "ENOENT";
        if (isMissingFile) {
            // A missing file simply means no snapshot has been recorded yet.
            return [];
        }
        throw e;
    }
}
30+
31+
/**
 * Returns a promise that settles once a `setTimeout` of `ms` milliseconds
 * fires. Used as the pause between snapshot write attempts.
 *
 * @param ms Delay, in milliseconds, handed to `setTimeout`.
 */
function waitFor(ms: number) {
    return new Promise((done) => {
        setTimeout(done, ms);
    });
}
34+
35+
/**
 * Validates `entry` and stores it at the front of the snapshot file.
 *
 * The whole snapshot is re-read, the entry is prepended, and the result is
 * written to a temporary sibling file that is then renamed over
 * `SNAPSHOT_FILE_PATH`, so readers never observe a half-written file. Failed
 * attempts are retried after a short randomized pause.
 *
 * NOTE(review): the read-modify-write cycle is not locked, so two writers that
 * read the same snapshot can still overwrite each other's entry; the retry
 * only covers attempts that throw.
 *
 * @param entry Record to store; must satisfy `AccuracySnapshotEntrySchema`.
 * @throws The zod validation error when `entry` is invalid (no retry), or the
 *   error of the last attempt once all attempts have failed.
 */
export async function appendAccuracySnapshot(entry: AccuracySnapshotEntry): Promise<void> {
    AccuracySnapshotEntrySchema.parse(entry);

    const maxAttempts = 5;
    for (let attempt = 1; ; attempt++) {
        // pid + timestamp + random suffix: a timestamp alone can repeat when
        // several processes write within the same millisecond, which would
        // make them share (and clobber) one temp file.
        const uniqueSuffix = `${process.pid}-${Date.now()}-${Math.random().toString(36).slice(2)}`;
        const tmp = `${SNAPSHOT_FILE_PATH}~${uniqueSuffix}`;
        try {
            const snapshot = await readSnapshot();
            snapshot.unshift(entry);
            await fs.writeFile(tmp, JSON.stringify(snapshot, null, 2));
            await fs.rename(tmp, SNAPSHOT_FILE_PATH);
            return;
        } catch (e: unknown) {
            // Best-effort cleanup so failed attempts do not leave temp files
            // behind; a cleanup failure must not replace the real error.
            await fs.rm(tmp, { force: true }).catch(() => undefined);
            if (attempt >= maxAttempts) {
                throw e;
            }
            await waitFor(100 + Math.random() * 200);
        }
    }
}

tests/accuracy/sdk/describe-accuracy-tests.ts

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { discoverMongoDBTools, TestTools, MockedTools } from "./test-tools.js";
33
import { TestableModels } from "./models.js";
44
import { ExpectedToolCall, parameterMatchingAccuracyScorer, toolCallingAccuracyScorer } from "./accuracy-scorers.js";
55
import { Agent, getVercelToolCallingAgent } from "./agent.js";
6+
import { appendAccuracySnapshot } from "./accuracy-snapshot.js";
67

78
interface AccuracyTestConfig {
89
prompt: string;
@@ -15,6 +16,20 @@ export function describeAccuracyTests(
1516
models: TestableModels,
1617
accuracyTestConfigs: AccuracyTestConfig[]
1718
) {
19+
const accuracyDatetime = process.env.ACCURACY_DATETIME;
20+
if (!accuracyDatetime) {
21+
throw new Error("ACCURACY_DATETIME environment variable is not set");
22+
}
23+
const accuracyCommit = process.env.ACCURACY_COMMIT;
24+
if (!accuracyCommit) {
25+
throw new Error("ACCURACY_COMMIT environment variable is not set");
26+
}
27+
28+
if (!models.length) {
29+
console.warn(`No models available to test ${suiteName}`);
30+
return;
31+
}
32+
1833
const eachModel = describe.each(models);
1934
const eachTest = it.each(accuracyTestConfigs);
2035

@@ -35,15 +50,30 @@ export function describeAccuracyTests(
3550
eachTest("$prompt", async function (testConfig) {
3651
testTools.mockTools(testConfig.mockedTools);
3752
const conversation = await agent.prompt(testConfig.prompt, model, testTools.vercelAiTools());
38-
console.log("conversation", conversation);
3953
const toolCalls = testTools.getToolCalls();
40-
console.log("?????? toolCalls", toolCalls);
41-
console.log("???? expected", testConfig.expectedToolCalls);
4254
const toolCallingAccuracy = toolCallingAccuracyScorer(testConfig.expectedToolCalls, toolCalls);
4355
const parameterMatchingAccuracy = parameterMatchingAccuracyScorer(testConfig.expectedToolCalls, toolCalls);
56+
await appendAccuracySnapshot({
57+
datetime: accuracyDatetime,
58+
commit: accuracyCommit,
59+
model: model.modelName,
60+
suite: suiteName,
61+
test: testConfig.prompt,
62+
toolCallingAccuracy,
63+
parameterAccuracy: parameterMatchingAccuracy,
64+
});
4465

45-
expect(toolCallingAccuracy).not.toEqual(0);
46-
expect(parameterMatchingAccuracy).toBeGreaterThanOrEqual(0.5);
66+
try {
67+
expect(toolCallingAccuracy).not.toEqual(0);
68+
expect(parameterMatchingAccuracy).toBeGreaterThanOrEqual(0.5);
69+
} catch (error) {
70+
console.warn(`Accuracy test failed for ${model.modelName} - ${suiteName} - ${testConfig.prompt}`);
71+
console.warn(`Conversation`, JSON.stringify(conversation, null, 2));
72+
console.warn(`Tool calls`, JSON.stringify(toolCalls, null, 2));
73+
console.warn(`Tool calling accuracy`, toolCallingAccuracy);
74+
console.warn(`Parameter matching accuracy`, parameterMatchingAccuracy);
75+
throw error;
76+
}
4777
});
4878
});
4979
}

tests/accuracy/sdk/models.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { createGoogleGenerativeAI } from "@himanshusinghs/google";
33
import { ollama } from "ollama-ai-provider";
44

55
export interface Model<P extends LanguageModelV1 = LanguageModelV1> {
6+
readonly modelName: string;
67
isAvailable(): boolean;
78
getModel(): P;
89
}
@@ -25,7 +26,7 @@ export class OllamaModel implements Model {
2526
constructor(readonly modelName: string) {}
2627

2728
isAvailable(): boolean {
28-
return true;
29+
return false;
2930
}
3031

3132
getModel() {
@@ -35,8 +36,8 @@ export class OllamaModel implements Model {
3536

3637
const ALL_TESTABLE_MODELS = [
3738
new GeminiModel("gemini-1.5-flash"),
38-
// new GeminiModel("gemini-2.0-flash"),
39-
// new OllamaModel("qwen3:latest"),
39+
new GeminiModel("gemini-2.0-flash"),
40+
new OllamaModel("qwen3:latest"),
4041
];
4142

4243
export type TestableModels = ReturnType<typeof getAvailableModels>;

tests/accuracy/sdk/test-tools.ts

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,6 @@ export async function discoverMongoDBTools(): Promise<Tool[]> {
132132
await mcpClient.connect(clientTransport);
133133

134134
return (await mcpClient.listTools()).tools;
135-
} catch (error: unknown) {
136-
console.error("Unexpected error occured", error);
137-
return [];
138135
} finally {
139136
await mcpClient?.close();
140137
await mcpServer?.session?.close();

0 commit comments

Comments
 (0)