Skip to content

Commit cda8caa

Browse files
committed
chore: add example accuracy test
1 parent 54effbb commit cda8caa

File tree

7 files changed

+323
-3
lines changed

7 files changed

+323
-3
lines changed

package-lock.json

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
"check:types": "tsc --noEmit --project tsconfig.json",
3030
"reformat": "prettier --write .",
3131
"generate": "./scripts/generate.sh",
32-
"test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage"
32+
"test": "npm run test:unit && npm run test:integration",
33+
"test:accuracy": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage --testPathPattern=tests/accuracy",
34+
"test:unit": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage --testPathPattern=tests/unit",
35+
"test:integration": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage --testPathPattern=tests/integration"
3336
},
3437
"license": "Apache-2.0",
3538
"devDependencies": {
@@ -57,7 +60,8 @@
5760
"tsx": "^4.19.3",
5861
"typescript": "^5.8.2",
5962
"typescript-eslint": "^8.29.1",
60-
"yaml": "^2.7.1"
63+
"yaml": "^2.7.1",
64+
"zod-to-json-schema": "^3.24.5"
6165
},
6266
"dependencies": {
6367
"@modelcontextprotocol/sdk": "^1.11.2",
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import { describeAccuracyTest } from "../test-sdk.js";
2+
3+
// Accuracy suite for single-step `find` queries. Each prompt is handed to
// describeAccuracyTest's `prompt` helper together with the arguments the
// `find` tool is expected to be called with.
describeAccuracyTest("1 step find queries", ({ prompt }) => {
    prompt("find all users in database 'my' and collection 'users'", (tool) => {
        tool("find").verifyCalled({ database: "my", collection: "users", limit: 10 });
    });

    // The prompt used to read "red cards", which contradicts the 'cars'
    // collection it targets; the wording now matches the expected filter.
    prompt("find all red cars in database 'production' and collection 'cars'", (tool) => {
        tool("find").verifyCalled({ filter: { color: "red" }, database: "production", collection: "cars", limit: 10 });
    });

    prompt("get 100 books in database 'prod' and collection 'books' where the author is J.R.R Tolkien", (tool) => {
        tool("find").verifyCalled({
            filter: { author: "J.R.R Tolkien" },
            database: "prod",
            collection: "books",
            limit: 100,
        });
    });
});

tests/accuracy/models/gemini.ts

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import { ModelFacade, ToolCall, ToolDefinition } from "./model.js";
2+
3+
/** Gemini model identifiers this facade can be constructed with. */
type GeminiModel = "gemini-2.0-flash" | "gemini-1.5-flash";

/**
 * ModelFacade implementation that sends a prompt plus tool definitions to the
 * Gemini `generateContent` REST endpoint and reports which tools the model
 * asked to call.
 *
 * The API key is read from the MONGODB_MCP_TEST_GEMINI_API_KEY environment
 * variable.
 */
export class GeminiModelFacade implements ModelFacade {
    readonly name: GeminiModel;

    constructor(modelName: GeminiModel) {
        this.name = modelName;
    }

    /**
     * @returns true only when the API key environment variable holds a
     * non-blank value. A variable that is set but empty counts as missing.
     */
    available(): boolean {
        const apiKey = process.env.MONGODB_MCP_TEST_GEMINI_API_KEY;
        return apiKey !== undefined && apiKey.trim().length > 0;
    }

    /**
     * Sends `prompt` as a single user message together with `tools`.
     *
     * @param prompt Text of the user message.
     * @param tools Tool definitions offered to the model as function declarations.
     * @returns Every function call found in the first candidate's parts; when
     * there is none, the concatenated text parts; otherwise a fixed fallback
     * text. HTTP and network failures are logged and reported through `text`
     * with an empty `toolCall` list rather than thrown.
     */
    async generateContent(prompt: string, tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }> {
        const functionDeclarations = tools.map((tool) => ({
            name: tool.name,
            description: tool.description,
            parameters: tool.parameters ?? {},
        }));

        const payload = {
            contents: [{ role: "user", parts: [{ text: prompt }] }],
            // `tools` is a list of tool objects, each carrying a flat list of
            // function declarations (the declarations must not be nested in
            // another array).
            tools: [{ function_declarations: functionDeclarations }],
        };

        const apiKey = process.env.MONGODB_MCP_TEST_GEMINI_API_KEY ?? "";
        const apiUrl = `https://generativelanguage.googleapis.com/v1beta/models/${this.name}:generateContent`;

        try {
            const response = await fetch(apiUrl, {
                method: "POST",
                // The key travels in a header so the secret is not part of the URL.
                headers: { "Content-Type": "application/json", "x-goog-api-key": apiKey },
                body: JSON.stringify(payload),
            });

            if (!response.ok) {
                const errorData = await response.text();
                console.error(`[Gemini API Error] HTTP error! status: ${response.status}, data: ${errorData}`);
                return { toolCall: [], text: `Gemini API error: ${response.status}` };
            }

            const result = (await response.json()) as {
                candidates?: Array<{
                    content?: {
                        parts?: Array<{
                            text?: string;
                            functionCall?: {
                                name: string;
                                args?: Record<string, unknown>;
                            };
                        }>;
                    };
                }>;
            };

            // Walk every part of the first candidate: a function call may follow
            // a leading text part, and several calls may be returned at once.
            const parts = result.candidates?.[0]?.content?.parts ?? [];
            const toolCall: ToolCall[] = [];
            const textParts: string[] = [];
            for (const part of parts) {
                if (part.functionCall) {
                    toolCall.push({ name: part.functionCall.name, args: part.functionCall.args ?? {} });
                } else if (part.text) {
                    textParts.push(part.text);
                }
            }

            if (toolCall.length > 0) {
                return { toolCall };
            }
            if (textParts.length > 0) {
                return { toolCall: [], text: textParts.join("") };
            }
            return { toolCall: [], text: "Gemini response was empty or unexpected." };
        } catch (error: unknown) {
            console.error("[Gemini API Fetch Error]", error);
            return { toolCall: [], text: `Error contacting Gemini LLM.` };
        }
    }
}

tests/accuracy/models/index.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import { ModelFacade } from "./model.js";
2+
import { GeminiModelFacade } from "./gemini.js";
3+
4+
/** Every model facade the accuracy tests know about, usable or not. */
const ALL_MODELS: ModelFacade[] = (["gemini-2.0-flash", "gemini-1.5-flash"] as const).map(
    (modelName) => new GeminiModelFacade(modelName)
);

/**
 * @returns the entries of ALL_MODELS whose `available()` check passes, in
 * declaration order.
 */
export function availableModels(): ModelFacade[] {
    const usable: ModelFacade[] = [];
    for (const candidate of ALL_MODELS) {
        if (candidate.available()) {
            usable.push(candidate);
        }
    }
    return usable;
}

tests/accuracy/models/model.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
/** One tool invocation requested by a model: the tool name and its arguments. */
export type ToolCall = {
    name: string;
    args: { [key: string]: unknown };
};

/** A tool as presented to a model: name, description and parameter schema object. */
export type ToolDefinition = {
    name: string;
    description: string;
    parameters: { [key: string]: unknown };
};

/** Common surface every model backend used by the accuracy tests implements. */
export interface ModelFacade {
    /** Label identifying the model. */
    name: string;

    /** Whether this model can be used in the current environment. */
    available(): boolean;

    /**
     * Submits `prompt` together with the offered `tools` and resolves with the
     * tool calls the model requested, plus optional text.
     */
    generateContent(prompt: string, tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }>;
}

tests/accuracy/test-sdk.ts

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
2+
import { McpServer, RegisteredTool, ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js";
3+
import { InMemoryTransport } from "@modelcontextprotocol/sdk/inMemory.js";
4+
import { Server } from "../../src/server.js";
5+
import { Session } from "../../src/session.js";
6+
import { Telemetry } from "../../src/telemetry/telemetry.js";
7+
import { config, UserConfig } from "../../src/config.js";
8+
import { afterEach } from "node:test";
9+
import { availableModels } from "./models/index.js";
10+
import { ToolDefinition } from "./models/model.js";
11+
import { zodToJsonSchema } from "zod-to-json-schema";
12+
13+
/**
 * Records what a mocked tool is expected to be called with, what it should
 * return, and what it was actually called with, so the two can be compared
 * once the prompt has been processed.
 */
class ToolMock {
    readonly name: string;
    /** Expected call arguments, set through verifyCalled(). */
    arguments: unknown;
    /** Value handed back to the caller of the mocked tool. */
    returns: unknown;
    /** Arguments of the most recent recorded call, if any. */
    wasCalledWith: unknown;
    /** True once verifyCalled() has registered an expectation. */
    private expectationSet = false;

    constructor(name: string) {
        this.name = name;
        this.arguments = {};
        this.returns = {};
    }

    /**
     * Declares that the tool must be called with `args`.
     * @returns this mock, for chaining.
     */
    verifyCalled(args: unknown): this {
        this.arguments = args;
        this.expectationSet = true;
        return this;
    }

    /**
     * Sets the value the mocked tool returns.
     * @returns this mock, for chaining.
     */
    thenReturn(value: unknown): this {
        this.returns = value;
        return this;
    }

    /**
     * Records the arguments of an actual call.
     * @returns this mock, for chaining.
     */
    _wasCalledWith(args: unknown): this {
        this.wasCalledWith = args;
        return this;
    }

    /**
     * Compares the recorded call with the expectation.
     *
     * @throws Error when verifyCalled() registered an expectation but no call
     * was recorded. Previously this case passed silently, so a model that
     * never invoked the tool still produced a green test.
     */
    _verify(): void {
        if (this.wasCalledWith) {
            expect(this.wasCalledWith).toEqual(this.arguments);
            return;
        }

        if (this.expectationSet) {
            throw new Error(
                `Expected tool "${this.name}" to be called with ${JSON.stringify(this.arguments)}, but no call was recorded.`
            );
        }
    }
}
48+
49+
/**
 * Shape used to reach the `mcpServer` member of the project Server through a
 * cast.
 */
interface McpServerUnsafe {
    mcpServer: McpServer;
}

/** Looks up a tool by name and hands back the mock controlling it. */
type AccuracyToolSetupFunction = (toolName: string) => ToolMock;

/** Body of one accuracy test: receives the tool-mocking helper. */
type AccuracyTestCaseFn = (tools: AccuracyToolSetupFunction) => void;

/** Registers one test for `prompt`, configured by `testCase`. */
type AccuracyItFn = (prompt: string, testCase: AccuracyTestCaseFn) => void;

/** Helpers passed to the callback given to describeAccuracyTest. */
interface AccuracyTestSuite {
    prompt: AccuracyItFn;
}
57+
58+
/**
 * Registers one `describe` suite per available model for the given use case.
 *
 * Inside each suite, `testCaseFn` receives a `prompt` helper. Every
 * `prompt(text, testCase)` call registers a test that:
 *  1. converts the MCP server's registered tools into ToolDefinitions,
 *  2. lets `testCase` replace tool callbacks with ToolMocks,
 *  3. sends the prompt to the model and executes requested tool calls through
 *     the MCP client, feeding tool output back to the model until no further
 *     call (or no text output) is produced,
 *  4. verifies every mock.
 *
 * @param useCase Label appended to the model name in the suite title.
 * @param testCaseFn Callback that declares the prompts of the suite.
 * @throws Error when no model is available.
 */
export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: AccuracyTestSuite) => void): void {
    const models = availableModels();
    if (models.length === 0) {
        throw new Error("No models available for accuracy tests.");
    }

    models.forEach((model) => {
        describe(`${model.name}: ${useCase}`, () => {
            let mcpServer: Server;
            let mcpClient: Client;
            let userConfig: UserConfig;
            let session: Session;
            let telemetry: Telemetry;

            // Each test gets a fresh client and server linked by an in-memory transport.
            beforeEach(async () => {
                mcpClient = new Client(
                    {
                        name: "test-client",
                        version: "1.2.3",
                    },
                    {
                        capabilities: {},
                    }
                );

                userConfig = { ...config };
                session = new Session(userConfig);
                telemetry = Telemetry.create(session, userConfig);

                mcpServer = new Server({
                    session,
                    userConfig,
                    telemetry,
                    mcpServer: new McpServer({
                        name: "test-server",
                        version: "5.2.3",
                    }),
                });

                const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();

                await Promise.all([mcpServer.connect(serverTransport), mcpClient.connect(clientTransport)]);
            });

            // NOTE(review): `afterEach` is imported from "node:test" at the top of this
            // file instead of being the test runner's own global hook; confirm this
            // cleanup actually runs under Jest.
            afterEach(async () => {
                await Promise.all([mcpServer.close(), mcpClient.close()]);
            });

            const promptFn: AccuracyItFn = (prompt: string, testCase: AccuracyTestCaseFn) => {
                it(prompt, async () => {
                    // The tool registry is an internal member of McpServer, hence the
                    // cast and bracket access. It is looked up once and shared by the
                    // schema conversion and the mocking helper below.
                    const mcpServerUnsafe = (mcpServer as unknown as McpServerUnsafe).mcpServer;
                    const tools = mcpServerUnsafe["_registeredTools"] as { [toolName: string]: RegisteredTool };

                    const toolDefinitions = Object.entries(tools).map(([toolName, tool]) => {
                        if (!tool.inputSchema) {
                            throw new Error(`Tool ${toolName} does not have an input schema defined.`);
                        }

                        const toolForApi: ToolDefinition = {
                            name: toolName,
                            description: tool.description ?? "",
                            parameters: zodToJsonSchema(tool.inputSchema, {
                                target: "jsonSchema7",
                                allowedAdditionalProperties: undefined,
                                rejectedAdditionalProperties: undefined,
                                // Blank out `$schema`, `const` and `additionalProperties`
                                // on every object-valued schema node.
                                postProcess: (schema) => {
                                    if (schema && typeof schema === "object") {
                                        return {
                                            ...schema,
                                            $schema: undefined,
                                            const: undefined,
                                            additionalProperties: undefined,
                                        };
                                    }
                                    return schema;
                                },
                            }),
                        };
                        delete toolForApi.parameters.$schema;
                        return toolForApi;
                    });

                    const mocks: Array<ToolMock> = [];
                    const toolFn: AccuracyToolSetupFunction = (toolName: string) => {
                        const registeredTool = tools[toolName];
                        if (registeredTool === undefined) {
                            // An unknown name used to yield a mock that was never wired
                            // in, so the test could not detect the mistake.
                            throw new Error(`Tool ${toolName} is not registered on the MCP server and cannot be mocked.`);
                        }

                        const mock = new ToolMock(toolName);
                        registeredTool.callback = ((args: unknown) => {
                            mock._wasCalledWith(args);
                            return mock.returns;
                        }) as unknown as ToolCallback;

                        mocks.push(mock);
                        return mock;
                    };

                    testCase(toolFn);

                    // Asks the model, runs the tool calls it requests, and feeds the
                    // text of the tool results back as the next prompt until the model
                    // stops calling tools or the results contain no text.
                    const consumePromptUntilNoMoreCall = async (prompt: string[]): Promise<void> => {
                        const promptStr = prompt.join("\n");
                        const response = await model.generateContent(promptStr, toolDefinitions);

                        if (response.toolCall.length > 0) {
                            const toolCallResults = await Promise.all(
                                response.toolCall.map((tc) =>
                                    mcpClient.callTool({
                                        name: tc.name,
                                        arguments: tc.args,
                                    })
                                )
                            );
                            // Only string `text` entries are forwarded; results without a
                            // content array contribute nothing instead of throwing.
                            const newPrompt = toolCallResults.flatMap((result) => {
                                const content: unknown = result.content;
                                if (!Array.isArray(content)) {
                                    return [];
                                }
                                return (content as Array<{ text?: unknown }>)
                                    .map((item) => item?.text)
                                    .filter((text): text is string => typeof text === "string");
                            });

                            if (newPrompt.join("\n").trim().length > 0) {
                                return consumePromptUntilNoMoreCall(newPrompt);
                            }
                        }
                    };

                    await consumePromptUntilNoMoreCall([prompt]);
                    mocks.forEach((mock) => mock._verify());
                });
            };

            testCaseFn({ prompt: promptFn });
        });
    });
}

0 commit comments

Comments
 (0)