Skip to content

Commit bbe2b83

Browse files
committed
chore: add plans for multistep actions
1 parent cda8caa commit bbe2b83

File tree

7 files changed

+203
-15
lines changed

7 files changed

+203
-15
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import { describeAccuracyTest } from "../test-sdk.js";
2+
3+
// Accuracy suite for a single-step delete request: the prompt below is expected
// to result in the `delete-many` tool being verified with the matching
// database, collection and filter.
describeAccuracyTest("1 step delete queries", (suite) => {
    const expectedDeleteArgs = {
        database: "my",
        collection: "users",
        filter: { disabled: true },
    };

    suite.prompt(
        "delete all disabled users (disabled = true) in database 'my' and collection 'users'",
        (tool) => {
            tool("delete-many").verifyCalled(expectedDeleteArgs);
        }
    );
});

tests/accuracy/1-step/simple-find-query.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ describeAccuracyTest("1 step find queries", ({ prompt }) => {
55
tool("find").verifyCalled({ database: "my", collection: "users", limit: 10 });
66
});
77

8-
prompt("find all red cards in database 'production' and collection 'cars'", (tool) => {
8+
prompt("find all red cars in database 'production' and collection 'cars'", (tool) => {
99
tool("find").verifyCalled({ filter: { color: "red" }, database: "production", collection: "cars", limit: 10 });
1010
});
1111

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import { describeAccuracyTest } from "../test-sdk.js";
2+
3+
// Accuracy suite for a single-step update request: the prompt below is expected
// to result in the `update-many` tool being verified with the matching
// database, collection, filter and `$set` update document.
describeAccuracyTest("1 step update queries", (suite) => {
    const expectedUpdateArgs = {
        database: "my",
        collection: "users",
        filter: { email: "" },
        update: { $set: { disabled: true } },
    };

    suite.prompt(
        "set all users with an empty email to disabled in database 'my' and collection 'users'",
        (tool) => {
            tool("update-many").verifyCalled(expectedUpdateArgs);
        }
    );
});
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import { describeAccuracyTest } from "../test-sdk.js";
2+
3+
// Accuracy suite for a two-step request: the prompt asks for a new collection
// 'users' in database 'my' followed by inserting one sample document, and the
// expectations verify a `create-collection` call and an `insert-many` call.
// NOTE(review): diff lines 8 and 24 of this hunk are empty at the same position
// (between username and password) in both the prompt and the expected document;
// a field may have been elided by extraction — confirm against the repository.
describeAccuracyTest("2 step create collection", ({ prompt }) => {
4+
prompt(
5+
`
6+
create a new collection named 'users' in database 'my' and afterwards create a sample document with the following data:
7+
- username: "john_doe"
8+
9+
- password: "password123"
10+
- disabled: false
11+
`,
12+
(tool) => {
13+
tool("create-collection").verifyCalled({
14+
database: "my",
15+
collection: "users",
16+
});
17+
18+
tool("insert-many").verifyCalled({
19+
database: "my",
20+
collection: "users",
21+
documents: [
22+
{
23+
username: "john_doe",
24+
25+
password: "password123",
26+
disabled: false,
27+
},
28+
],
29+
});
30+
}
31+
);
32+
});

tests/accuracy/models/gemini.ts

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,79 @@ export class GeminiModelFacade implements ModelFacade {
1313
return process.env.MONGODB_MCP_TEST_GEMINI_API_KEY !== undefined;
1414
}
1515

16-
async generateContent(prompt: string, tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }> {
16+
/**
 * Asks the Gemini model to break `prompt` down into a list of plan steps.
 *
 * The model is instructed to reply with a bare JSON array of strings. Markdown
 * code fences are stripped from the reply before it is parsed, and the parsed
 * value is validated so callers always receive an array of strings.
 *
 * @param prompt - Task description the plan should cover.
 * @param tools - Tool definitions sent along with the request.
 * @returns The plan steps, or an empty array when the API key is missing, the
 *          request fails, or the reply is not a JSON array of strings.
 */
async generatePlan(prompt: string, tools: ToolDefinition[]): Promise<string[]> {
    const planPrompt = `You are an expert MongoDB developer. Create a plan for the following task: \n ${prompt} \n Return the plan as a list of steps, as a JSON array. For example: [ "Step 1: ...", "Step 2: ...", "Step 3: ..." ]. Only return the JSON array, nothing else. Do not include any wrapper markdown or anything, just the plain JSON array.`;
    const chatHistory = [{ role: "user", parts: [{ text: planPrompt }] }];

    const apiKey = process.env.MONGODB_MCP_TEST_GEMINI_API_KEY;
    if (apiKey === undefined) {
        // Without this guard the literal string "undefined" would be sent as the key.
        console.error("[Gemini API Error] MONGODB_MCP_TEST_GEMINI_API_KEY is not set");
        return [];
    }
    const apiUrl = `https://generativelanguage.googleapis.com/v1beta/models/${this.name}:generateContent?key=${apiKey}`;

    const toolDefinitions = tools.map((tool) => ({
        name: tool.name,
        description: tool.description,
        parameters: tool.parameters || {},
    }));

    // NOTE(review): `function_declarations` receives `toolDefinitions` wrapped in
    // an extra array. The payload shape is kept exactly as it was; confirm it
    // against the Gemini API reference.
    const payload = {
        contents: chatHistory,
        tools: {
            function_declarations: [toolDefinitions],
        },
    };

    try {
        const response = await fetch(apiUrl, {
            method: "POST",
            headers: { "Content-Type": "application/json" },
            body: JSON.stringify(payload),
        });

        if (!response.ok) {
            const errorData = await response.text();
            console.error(`[Gemini API Error] HTTP error! status: ${response.status}, data: ${errorData}`);
            return [];
        }

        // Every level is optional: a reply may carry no candidates or no parts.
        const result = (await response.json()) as {
            candidates?: Array<{
                content?: {
                    parts?: Array<{
                        text?: string;
                        functionCall?: {
                            name: string;
                            args: Record<string, unknown>;
                        };
                    }>;
                };
            }>;
        };

        const responseString = (result.candidates ?? [])
            .flatMap((candidate) => (candidate.content?.parts ?? []).map((part) => part.text ?? ""))
            .join("")
            // Strip every markdown fence, not just the first opening/closing one.
            .replace(/```(?:json)?/g, "")
            .trim();

        let parsed: unknown;
        try {
            parsed = JSON.parse(responseString);
        } catch (parseError: unknown) {
            console.error("[Gemini API JSON.parse Error]", responseString, parseError);
            return [];
        }

        // Valid JSON is not necessarily a plan; reject anything that is not string[].
        if (!Array.isArray(parsed) || !parsed.every((step): step is string => typeof step === "string")) {
            console.error("[Gemini API Plan Error] Expected a JSON array of strings, got:", responseString);
            return [];
        }

        return parsed;
    } catch (error: unknown) {
        console.error("[Gemini API Fetch Error]", error);
        return [];
    }
}
80+
81+
async generateContent(parts: string[], tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }> {
1782
const toolDefinitions = tools.map((tool) => ({
1883
name: tool.name,
1984
description: tool.description,
2085
parameters: tool.parameters || {},
2186
}));
2287

23-
const chatHistory = [{ role: "user", parts: [{ text: prompt }] }];
88+
const chatHistory = [{ role: "user", parts: parts.map((part) => ({ text: part })) }];
2489
const payload = {
2590
contents: chatHistory,
2691
tools: {

tests/accuracy/models/model.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,7 @@ export type ToolDefinition = {
88
/**
 * Contract implemented by each model backend used by the accuracy tests
 * (for example `GeminiModelFacade`): a `name`, an `available()` check, plan
 * generation for a prompt, and content generation that may yield tool calls.
 * This hunk shows both the removed `generateContent(prompt: string, ...)`
 * signature and the added `generateContent(parts: string[], ...)` one.
 */
export interface ModelFacade {
99
name: string;
1010
available(): boolean;
11-
generateContent(prompt: string, tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }>;
11+
12+
generatePlan(prompt: string, tools: ToolDefinition[]): Promise<string[]>;
13+
generateContent(parts: string[], tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }>;
1214
}

tests/accuracy/test-sdk.ts

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { availableModels } from "./models/index.js";
1010
import { ToolDefinition } from "./models/model.js";
1111
import { zodToJsonSchema } from "zod-to-json-schema";
1212

13+
type ToolMockReturn = { content: Array<{ type: string; text: string }> };
1314
class ToolMock {
1415
readonly name: string;
1516
arguments: unknown;
@@ -27,7 +28,7 @@ class ToolMock {
2728
return this;
2829
}
2930

30-
thenReturn(value: unknown): this {
31+
thenReturn(value: ToolMockReturn): this {
3132
this.returns = value;
3233
return this;
3334
}
@@ -55,6 +56,36 @@ type AccuracyTestCaseFn = (tools: AccuracyToolSetupFunction) => void;
5556
type AccuracyItFn = (prompt: string, testCase: AccuracyTestCaseFn) => void;
5657
type AccuracyTestSuite = { prompt: AccuracyItFn };
5758

59+
type NonMockedCallError = { tool: string; args: unknown };
60+
61+
/** Reports whether verbose logging is switched on (`MONGODB_MCP_TEST_VERBOSE === "true"`). */
function isVerboseLoggingEnabled(): boolean {
    return process.env.MONGODB_MCP_TEST_VERBOSE === "true";
}

/**
 * Serializes a value for log output without ever throwing.
 * `JSON.stringify` throws on circular structures and BigInt values; a logging
 * helper must not fail the surrounding test because of that.
 */
function stringifyForLog(value: unknown): string | undefined {
    try {
        return JSON.stringify(value);
    } catch {
        return "<unserializable>";
    }
}

/** Forwards its arguments to `console.log`, but only when verbose logging is enabled. */
function logVerbose(...args: unknown[]): void {
    if (isVerboseLoggingEnabled()) {
        console.log(...args);
    }
}

/** Logs the plan produced by `model`, one step per line. */
function printModelPlanIfVerbose(model: string, plan: string[]): void {
    logVerbose(model, "📝: ", plan.join("\n"));
}

/** Logs the prompt that is about to be sent to `model`. */
function testPromptIsVerbose(model: string, prompt: string): void {
    logVerbose(model, "📜: ", prompt);
}

/** Logs the text answer of `model`; empty answers are skipped. */
function modelSaidVerbose(model: string, response: string): void {
    if (response.length > 0) {
        logVerbose(model, "🗣️: ", response);
    }
}

/**
 * Logs a tool call requested by `model` together with its JSON-encoded arguments.
 * The arguments are only serialized when verbose logging is enabled, so
 * unserializable arguments cannot throw while logging is off.
 */
function modelToolCalledVerbose(model: string, toolCall: string, args: unknown): void {
    if (!isVerboseLoggingEnabled()) {
        return;
    }
    logVerbose(model, "🛠️: ", toolCall, stringifyForLog(args));
}

/** Logs the combined tool-call results that are fed back to `model`. */
function toolCallsReturnedVerbose(model: string, answer: string): void {
    logVerbose(model, "📋: ", answer);
}
88+
5889
export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: AccuracyTestSuite) => void) {
5990
const models = availableModels();
6091
if (models.length === 0) {
@@ -105,8 +136,13 @@ export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: Ac
105136

106137
const promptFn: AccuracyItFn = (prompt: string, testCase: AccuracyTestCaseFn) => {
107138
it(prompt, async () => {
139+
testPromptIsVerbose(model.name, prompt);
140+
108141
const mcpServerUnsafe = (mcpServer as unknown as McpServerUnsafe).mcpServer;
109142
const tools = mcpServerUnsafe["_registeredTools"] as { [toolName: string]: RegisteredTool };
143+
const mockedTools = new Set<string>();
144+
const nonMockedCallErrors = new Array<NonMockedCallError>();
145+
110146
const toolDefinitions = Object.entries(tools).map(([toolName, tool]) => {
111147
if (!tool.inputSchema) {
112148
throw new Error(`Tool ${toolName} does not have an input schema defined.`);
@@ -136,17 +172,22 @@ export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: Ac
136172
return toolForApi;
137173
});
138174

139-
const mocks: Array<ToolMock> = [];
175+
const plan = await model.generatePlan(prompt, toolDefinitions);
176+
printModelPlanIfVerbose(model.name, plan);
177+
178+
179+
const mocks: Array<ToolMock> = [];
140180
const toolFn: AccuracyToolSetupFunction = (toolName: string) => {
141181
const mock = new ToolMock(toolName);
182+
mockedTools.add(toolName);
142183

143184
const mcpServerUnsafe = (mcpServer as unknown as McpServerUnsafe).mcpServer;
144185
const tools = mcpServerUnsafe["_registeredTools"] as { [toolName: string]: RegisteredTool };
145186

146187
if (tools[toolName] !== undefined) {
147188
tools[toolName].callback = ((args: unknown) => {
148189
mock._wasCalledWith(args);
149-
return mock.returns;
190+
return Promise.resolve(mock.returns);
150191
}) as unknown as ToolCallback;
151192
}
152193

@@ -157,30 +198,55 @@ export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: Ac
157198
testCase(toolFn);
158199

159200
const consumePromptUntilNoMoreCall = async (prompt: string[]) => {
160-
const promptStr = prompt.join("\n");
161-
const response = await model.generateContent(promptStr, toolDefinitions);
201+
const response = await model.generateContent(prompt, toolDefinitions);
162202

203+
modelSaidVerbose(model.name, response.text || "<no text>");
163204
if (response.toolCall.length > 0) {
164205
const toolCallResults = await Promise.all(
165-
response.toolCall.map((tc) =>
166-
mcpClient.callTool({
206+
response.toolCall.map((tc) => {
207+
modelToolCalledVerbose(model.name, tc.name, tc.args);
208+
209+
if (!mockedTools.has(tc.name)) {
210+
nonMockedCallErrors.push({ tool: tc.name, args: tc.args });
211+
}
212+
213+
return mcpClient.callTool({
167214
name: tc.name,
168215
arguments: tc.args,
169-
})
170-
)
216+
});
217+
})
171218
);
172-
const newPrompt = toolCallResults.flatMap((result) =>
219+
220+
const responseParts = toolCallResults.flatMap((result) =>
173221
(result.content as Array<{ text: string }>).map((c) => c.text)
174222
);
175223

176-
if (newPrompt.join("\n").trim().length > 0) {
224+
const newPrompt = prompt.concat(responseParts);
225+
toolCallsReturnedVerbose(model.name, newPrompt.join("\n"));
226+
227+
if (responseParts.length > 0) {
177228
return consumePromptUntilNoMoreCall(newPrompt);
178229
}
179230
}
180231
};
181232

233+
for (const step of plan) {
234+
await consumePromptUntilNoMoreCall([ step ]);
235+
}
236+
182237
await consumePromptUntilNoMoreCall([prompt]);
238+
183239
mocks.forEach((mock) => mock._verify());
240+
if (nonMockedCallErrors.length > 0) {
241+
for (const call of nonMockedCallErrors) {
242+
console.error(
243+
`Non-mocked tool call detected: ${call.tool} with args:`,
244+
JSON.stringify(call.args, null, 2)
245+
);
246+
}
247+
248+
throw new Error("Non-mocked tool calls detected. Check the console for details.");
249+
}
184250
});
185251
};
186252

0 commit comments

Comments
 (0)