|
| 1 | +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; |
| 2 | +import { McpServer, RegisteredTool, ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js"; |
| 3 | +import { InMemoryTransport } from "@modelcontextprotocol/sdk/inMemory.js"; |
| 4 | +import { Server } from "../../src/server.js"; |
| 5 | +import { Session } from "../../src/session.js"; |
| 6 | +import { Telemetry } from "../../src/telemetry/telemetry.js"; |
| 7 | +import { config, UserConfig } from "../../src/config.js"; |
| 8 | +import { afterEach } from "node:test"; |
| 9 | +import { availableModels } from "./models/index.js"; |
| 10 | +import { ToolDefinition } from "./models/model.js"; |
| 11 | +import { zodToJsonSchema } from "zod-to-json-schema"; |
| 12 | + |
/**
 * Records the expectation and the observed invocation for one mocked MCP tool.
 *
 * A test configures the mock through the fluent `verifyCalled` / `thenReturn`
 * methods; the harness reports the actual call through `_wasCalledWith` and
 * finally checks the expectation through `_verify`.
 */
class ToolMock {
    /** Name of the tool this mock stands in for. */
    readonly name: string;
    /** Arguments the tool is expected to receive (defaults to `{}`). */
    arguments: unknown;
    /** Value handed back to the caller when the mocked tool is invoked (defaults to `{}`). */
    returns: unknown;
    /** Arguments of the most recent recorded invocation, if any. */
    wasCalledWith: unknown;

    /** True once `verifyCalled` registered an expectation that the tool must be invoked. */
    private callExpected = false;
    /** True once `_wasCalledWith` recorded an invocation, even one with falsy arguments. */
    private called = false;

    constructor(name: string) {
        this.name = name;
        this.arguments = {};
        this.returns = {};
    }

    /**
     * Declares that the tool must be called with exactly `args`.
     *
     * @param args - The expected call arguments.
     * @returns This mock, for chaining.
     */
    verifyCalled(args: unknown): this {
        this.arguments = args;
        this.callExpected = true;
        return this;
    }

    /**
     * Sets the value the mocked tool returns when invoked.
     *
     * @param value - The canned result.
     * @returns This mock, for chaining.
     */
    thenReturn(value: unknown): this {
        this.returns = value;
        return this;
    }

    /**
     * Records an invocation of the mocked tool. Intended for the harness only.
     *
     * @param args - The arguments the tool was actually called with.
     * @returns This mock, for chaining.
     */
    _wasCalledWith(args: unknown): this {
        this.wasCalledWith = args;
        this.called = true;
        return this;
    }

    /**
     * Checks the recorded invocation against the expectation.
     *
     * - A call was recorded: its arguments must equal the expected arguments.
     * - No call was recorded but `verifyCalled` was used: the check fails.
     *   Previously this case passed silently because the fallback assertion
     *   compared the expected arguments against `null`.
     * - No call was recorded and none was expected: nothing to check.
     *
     * @throws Error when an expected call never happened; the assertion
     *   library's error when the recorded arguments do not match.
     */
    _verify(): void {
        if (!this.called) {
            if (this.callExpected) {
                throw new Error(
                    `Expected tool "${this.name}" to be called with ${JSON.stringify(this.arguments)}, but it was never called.`
                );
            }
            return;
        }

        expect(this.wasCalledWith).toEqual(this.arguments);
    }
}
| 48 | + |
/**
 * View of the project `Server` that exposes its wrapped SDK `McpServer`,
 * used by the harness (through a cast) to reach the registered tools.
 */
type McpServerUnsafe = {
    mcpServer: McpServer;
};

/** Creates (and registers) a mock for the tool with the given name. */
type AccuracyToolSetupFunction = (toolName: string) => ToolMock;

/** Body of a single accuracy test: receives the mock factory and sets up expectations. */
type AccuracyTestCaseFn = (tools: AccuracyToolSetupFunction) => void;

/** Declares one accuracy test for a prompt. */
type AccuracyItFn = (prompt: string, testCase: AccuracyTestCaseFn) => void;

/** Object handed to the suite callback of `describeAccuracyTest`. */
interface AccuracyTestSuite {
    prompt: AccuracyItFn;
}
| 57 | + |
/**
 * Declares an accuracy test suite that is run once per available model.
 *
 * For every model returned by `availableModels()` a `describe` block is
 * created. Before each test an MCP client and server are connected through an
 * in-memory transport. Each test declared via `testSuite.prompt(...)`:
 *
 * 1. converts the server's registered tools into tool definitions for the model,
 * 2. lets the test case replace tool callbacks with `ToolMock`s,
 * 3. sends the prompt to the model and executes the tool calls it requests
 *    through the MCP client, feeding the textual results back to the model
 *    until it stops calling tools (or the results are empty),
 * 4. verifies every mock.
 *
 * @param useCase - Human-readable label appended to the model name in the suite title.
 * @param testCaseFn - Callback that declares the prompts of this suite.
 * @throws Error if no models are available.
 */
export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: AccuracyTestSuite) => void): void {
    // Upper bound on model -> tools -> model round trips, so a model that keeps
    // requesting tools fails the test instead of looping until the runner times out.
    const maxToolCallRounds = 10;

    const models = availableModels();
    if (models.length === 0) {
        throw new Error("No models available for accuracy tests.");
    }

    models.forEach((model) => {
        describe(`${model.name}: ${useCase}`, () => {
            let mcpServer: Server;
            let mcpClient: Client;
            let userConfig: UserConfig;
            let session: Session;
            let telemetry: Telemetry;

            // Reaches into the SDK server's private tool registry; bracket access
            // is what lets TypeScript read the private member.
            const registeredTools = (): { [toolName: string]: RegisteredTool } => {
                const sdkServer = (mcpServer as unknown as McpServerUnsafe).mcpServer;
                return sdkServer["_registeredTools"] as { [toolName: string]: RegisteredTool };
            };

            // Keeps only the string `text` fields of a tool result. Items without
            // text, or a result without a content array, contribute nothing
            // rather than "undefined" or a TypeError.
            const textOf = (result: { content?: unknown }): string[] => {
                const content = result.content;
                if (!Array.isArray(content)) {
                    return [];
                }
                return content.flatMap((item: unknown) => {
                    const text = typeof item === "object" && item !== null ? (item as { text?: unknown }).text : undefined;
                    return typeof text === "string" ? [text] : [];
                });
            };

            beforeEach(async () => {
                mcpClient = new Client(
                    {
                        name: "test-client",
                        version: "1.2.3",
                    },
                    {
                        capabilities: {},
                    }
                );

                // Copy the shared config so each test works on its own object.
                userConfig = { ...config };
                session = new Session(userConfig);
                telemetry = Telemetry.create(session, userConfig);

                mcpServer = new Server({
                    session,
                    userConfig,
                    telemetry,
                    mcpServer: new McpServer({
                        name: "test-server",
                        version: "5.2.3",
                    }),
                });

                const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();

                await Promise.all([mcpServer.connect(serverTransport), mcpClient.connect(clientTransport)]);
            });

            // NOTE(review): `afterEach` is imported from "node:test" at the top of this
            // file, while describe/beforeEach/it are used as runner globals. Confirm this
            // teardown is registered with the same test runner as the other hooks.
            afterEach(async () => {
                await Promise.all([mcpServer.close(), mcpClient.close()]);
            });

            const promptFn: AccuracyItFn = (prompt: string, testCase: AccuracyTestCaseFn) => {
                it(prompt, async () => {
                    const tools = registeredTools();

                    const toolDefinitions = Object.entries(tools).map(([toolName, tool]) => {
                        if (!tool.inputSchema) {
                            throw new Error(`Tool ${toolName} does not have an input schema defined.`);
                        }

                        const toolForApi: ToolDefinition = {
                            name: toolName,
                            description: tool.description ?? "",
                            parameters: zodToJsonSchema(tool.inputSchema, {
                                target: "jsonSchema7",
                                allowedAdditionalProperties: undefined,
                                rejectedAdditionalProperties: undefined,
                                // Blank out keywords on every generated (sub)schema
                                // before it is handed to the model API.
                                postProcess: (schema) => {
                                    if (schema && typeof schema === "object") {
                                        return {
                                            ...schema,
                                            $schema: undefined,
                                            const: undefined,
                                            additionalProperties: undefined,
                                        };
                                    }
                                    return schema;
                                },
                            }),
                        };
                        delete toolForApi.parameters.$schema;
                        return toolForApi;
                    });

                    const mocks: ToolMock[] = [];
                    const toolFn: AccuracyToolSetupFunction = (toolName: string) => {
                        const tool = tools[toolName];
                        if (tool === undefined) {
                            // A mock for an unregistered tool could never be hit, so the
                            // test would be meaningless; fail loudly instead.
                            throw new Error(`Cannot mock tool ${toolName}: it is not registered on the server.`);
                        }

                        const mock = new ToolMock(toolName);
                        tool.callback = ((args: unknown) => {
                            mock._wasCalledWith(args);
                            return mock.returns;
                        }) as unknown as ToolCallback;

                        mocks.push(mock);
                        return mock;
                    };

                    testCase(toolFn);

                    // Drive the conversation: each round sends the pending text to the
                    // model and executes the tool calls it asks for. The next round is
                    // fed only the textual tool results.
                    let pending: string[] = [prompt];
                    let finished = false;
                    for (let round = 0; round < maxToolCallRounds && !finished; round++) {
                        const response = await model.generateContent(pending.join("\n"), toolDefinitions);

                        if (response.toolCall.length === 0) {
                            finished = true;
                            continue;
                        }

                        const toolCallResults = await Promise.all(
                            response.toolCall.map((tc) =>
                                mcpClient.callTool({
                                    name: tc.name,
                                    arguments: tc.args,
                                })
                            )
                        );

                        pending = toolCallResults.flatMap((result) => textOf(result));
                        if (pending.join("\n").trim().length === 0) {
                            finished = true;
                        }
                    }

                    if (!finished) {
                        throw new Error(
                            `Model ${model.name} was still requesting tool calls after ${maxToolCallRounds} rounds.`
                        );
                    }

                    for (const mock of mocks) {
                        mock._verify();
                    }
                });
            };

            testCaseFn({ prompt: promptFn });
        });
    });
}
0 commit comments