From 2c7351ffb9ffc87e3be3e46ab98aea44a0d8fcbd Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Tue, 2 Sep 2025 13:15:00 +0200 Subject: [PATCH 1/5] chore: allows injecting a connection error handler We're expanding the transport runner interface to accept a custom connection error handler so that library consumers can provide their own mechanism of handling connection error and have an opportunity to act on them. --- src/common/connectionErrorHandler.ts | 81 +++++ src/common/errors.ts | 4 +- src/index.ts | 9 +- src/server.ts | 6 +- src/tools/mongodb/mongodbTool.ts | 69 +---- src/transports/base.ts | 31 +- src/transports/stdio.ts | 14 +- src/transports/streamableHttp.ts | 14 +- tests/integration/helpers.ts | 2 + .../tools/mongodb/mongodbTool.test.ts | 290 ++++++++++++++++++ .../transports/streamableHttp.test.ts | 10 +- 11 files changed, 442 insertions(+), 88 deletions(-) create mode 100644 src/common/connectionErrorHandler.ts create mode 100644 tests/integration/tools/mongodb/mongodbTool.test.ts diff --git a/src/common/connectionErrorHandler.ts b/src/common/connectionErrorHandler.ts new file mode 100644 index 000000000..e1ddce66a --- /dev/null +++ b/src/common/connectionErrorHandler.ts @@ -0,0 +1,81 @@ +import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; +import { ErrorCodes, type MongoDBError } from "./errors.js"; +import type { AnyConnectionState } from "./connectionManager.js"; +import type { ToolBase } from "../tools/tool.js"; + +export type ConnectionErrorHandler = ( + error: MongoDBError, + additionalContext: { + availableTools: ToolBase[]; + connectionState: AnyConnectionState; + } +) => { errorHandled: false } | { errorHandled: true; result: CallToolResult }; + +export const connectionErrorHandler: ConnectionErrorHandler = (error, { availableTools, connectionState }) => { + const connectTools = availableTools + .filter((t) => t.operationType === "connect") + .sort((a, b) => a.category.localeCompare(b.category)); // Sort Altas tools before MongoDB tools + + // Find the first Atlas connect tool if available and suggest to the LLM to use it. + // Note: if we ever have multiple Atlas connect tools, we may want to refine this logic to select the most appropriate one. + const atlasConnectTool = connectTools?.find((t) => t.category === "atlas"); + const llmConnectHint = atlasConnectTool + ? `Note to LLM: prefer using the "${atlasConnectTool.name}" tool to connect to an Atlas cluster over using a connection string. Make sure to ask the user to specify a cluster name they want to connect to or ask them if they want to use the "list-clusters" tool to list all their clusters. Do not invent cluster names or connection strings unless the user has explicitly specified them. If they've previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same cluster/connection.` + : "Note to LLM: do not invent connection strings and explicitly ask the user to provide one. If they have previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same connection string."; + + const connectToolsNames = connectTools?.map((t) => `"${t.name}"`).join(", "); + const additionalPromptForConnectivity: { type: "text"; text: string }[] = []; + + if (connectionState.tag === "connecting" && connectionState.oidcConnectionType) { + additionalPromptForConnectivity.push({ + type: "text", + text: `The user needs to finish their OIDC connection by opening '${connectionState.oidcLoginUrl}' in the browser and use the following user code: '${connectionState.oidcUserCode}'`, + }); + } else { + additionalPromptForConnectivity.push({ + type: "text", + text: connectToolsNames + ? `Please use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance or update the MCP server configuration to include a connection string. ${llmConnectHint}` + : "There are no tools available to connect. Please update the configuration to include a connection string and restart the server.", + }); + } + + switch (error.code) { + case ErrorCodes.NotConnectedToMongoDB: + return { + errorHandled: true, + result: { + content: [ + { + type: "text", + text: "You need to connect to a MongoDB instance before you can access its data.", + }, + ...additionalPromptForConnectivity, + ], + isError: true, + }, + }; + case ErrorCodes.MisconfiguredConnectionString: + return { + errorHandled: true, + result: { + content: [ + { + type: "text", + text: "The configured connection string is not valid. Please check the connection string and confirm it points to a valid MongoDB instance.", + }, + { + type: "text", + text: connectTools + ? `Alternatively, you can use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance. ${llmConnectHint}` + : "Please update the configuration to use a valid connection string and restart the server.", + }, + ], + isError: true, + }, + }; + + default: + return { errorHandled: false }; + } +}; diff --git a/src/common/errors.ts b/src/common/errors.ts index d81867f09..084d45ca7 100644 --- a/src/common/errors.ts +++ b/src/common/errors.ts @@ -4,9 +4,9 @@ export enum ErrorCodes { ForbiddenCollscan = 1_000_002, } -export class MongoDBError extends Error { +export class MongoDBError extends Error { constructor( - public code: ErrorCodes, + public code: ErrorCode, message: string ) { super(message); diff --git a/src/index.ts b/src/index.ts index 6a7150e35..f7699518a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -49,7 +49,14 @@ async function main(): Promise { assertHelpMode(); assertVersionMode(); - const transportRunner = config.transport === "stdio" ? new StdioRunner(config) : new StreamableHttpRunner(config); + const transportRunner = + config.transport === "stdio" + ? new StdioRunner({ + userConfig: config, + }) + : new StreamableHttpRunner({ + userConfig: config, + }); const shutdown = (): void => { transportRunner.logger.info({ id: LogId.serverCloseRequested, diff --git a/src/server.ts b/src/server.ts index 76399e986..914a823b9 100644 --- a/src/server.ts +++ b/src/server.ts @@ -21,12 +21,14 @@ import assert from "assert"; import type { ToolBase } from "./tools/tool.js"; import { validateConnectionString } from "./helpers/connectionOptions.js"; import { packageInfo } from "./common/packageInfo.js"; +import { type ConnectionErrorHandler } from "./common/connectionErrorHandler.js"; export interface ServerOptions { session: Session; userConfig: UserConfig; mcpServer: McpServer; telemetry: Telemetry; + connectionErrorHandler: ConnectionErrorHandler; } export class Server { @@ -35,6 +37,7 @@ export class Server { private readonly telemetry: Telemetry; public readonly userConfig: UserConfig; public readonly tools: ToolBase[] = []; + public readonly connectionErrorHandler: ConnectionErrorHandler; private _mcpLogLevel: LogLevel = "debug"; @@ -45,12 +48,13 @@ export class Server { private readonly startTime: number; private readonly subscriptions = new Set(); - constructor({ session, mcpServer, userConfig, telemetry }: ServerOptions) { + constructor({ session, mcpServer, userConfig, telemetry, connectionErrorHandler }: ServerOptions) { this.startTime = Date.now(); this.session = session; this.telemetry = telemetry; this.mcpServer = mcpServer; this.userConfig = userConfig; + this.connectionErrorHandler = connectionErrorHandler; } async connect(transport: Transport): Promise { diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index 5fff778a6..fd1981818 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -56,63 +56,22 @@ export abstract class MongoDBToolBase extends ToolBase { args: ToolArgs ): Promise | CallToolResult { if (error instanceof MongoDBError) { - const connectTools = this.server?.tools - .filter((t) => t.operationType === "connect") - .sort((a, b) => a.category.localeCompare(b.category)); // Sort Altas tools before MongoDB tools - - // Find the first Atlas connect tool if available and suggest to the LLM to use it. - // Note: if we ever have multiple Atlas connect tools, we may want to refine this logic to select the most appropriate one. - const atlasConnectTool = connectTools?.find((t) => t.category === "atlas"); - const llmConnectHint = atlasConnectTool - ? `Note to LLM: prefer using the "${atlasConnectTool.name}" tool to connect to an Atlas cluster over using a connection string. Make sure to ask the user to specify a cluster name they want to connect to or ask them if they want to use the "list-clusters" tool to list all their clusters. Do not invent cluster names or connection strings unless the user has explicitly specified them. If they've previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same cluster/connection.` - : "Note to LLM: do not invent connection strings and explicitly ask the user to provide one. If they have previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same connection string."; - - const connectToolsNames = connectTools?.map((t) => `"${t.name}"`).join(", "); - const connectionStatus = this.session.connectionManager.currentConnectionState; - const additionalPromptForConnectivity: { type: "text"; text: string }[] = []; - - if (connectionStatus.tag === "connecting" && connectionStatus.oidcConnectionType) { - additionalPromptForConnectivity.push({ - type: "text", - text: `The user needs to finish their OIDC connection by opening '${connectionStatus.oidcLoginUrl}' in the browser and use the following user code: '${connectionStatus.oidcUserCode}'`, - }); - } else { - additionalPromptForConnectivity.push({ - type: "text", - text: connectToolsNames - ? `Please use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance or update the MCP server configuration to include a connection string. ${llmConnectHint}` - : "There are no tools available to connect. Please update the configuration to include a connection string and restart the server.", - }); - } - switch (error.code) { case ErrorCodes.NotConnectedToMongoDB: - return { - content: [ - { - type: "text", - text: "You need to connect to a MongoDB instance before you can access its data.", - }, - ...additionalPromptForConnectivity, - ], - isError: true, - }; - case ErrorCodes.MisconfiguredConnectionString: - return { - content: [ - { - type: "text", - text: "The configured connection string is not valid. Please check the connection string and confirm it points to a valid MongoDB instance.", - }, - { - type: "text", - text: connectTools - ? `Alternatively, you can use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance. ${llmConnectHint}` - : "Please update the configuration to use a valid connection string and restart the server.", - }, - ], - isError: true, - }; + case ErrorCodes.MisconfiguredConnectionString: { + const connectionError = error as MongoDBError< + ErrorCodes.NotConnectedToMongoDB | ErrorCodes.MisconfiguredConnectionString + >; + const outcome = this.server?.connectionErrorHandler(connectionError, { + availableTools: this.server?.tools ?? [], + connectionState: this.session.connectionManager.currentConnectionState, + }); + if (outcome?.errorHandled) { + return outcome.result; + } else { + return super.handleError(error, args); + } + } case ErrorCodes.ForbiddenCollscan: return { content: [ diff --git a/src/transports/base.ts b/src/transports/base.ts index d6fc53adc..6fe198c21 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -8,17 +8,35 @@ import type { LoggerBase } from "../common/logger.js"; import { CompositeLogger, ConsoleLogger, DiskLogger, McpLogger } from "../common/logger.js"; import { ExportsManager } from "../common/exportsManager.js"; import { DeviceId } from "../helpers/deviceId.js"; -import { type ConnectionManagerFactoryFn } from "../common/connectionManager.js"; +import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js"; +import { + type ConnectionErrorHandler, + connectionErrorHandler as defaultConnectionErrorHandler, +} from "../common/connectionErrorHandler.js"; + +export type TransportRunnerConfig = { + userConfig: UserConfig; + createConnectionManager?: ConnectionManagerFactoryFn; + connectionErrorHandler?: ConnectionErrorHandler; + additionalLoggers?: LoggerBase[]; +}; export abstract class TransportRunnerBase { public logger: LoggerBase; public deviceId: DeviceId; + protected readonly userConfig: UserConfig; + private readonly createConnectionManager: ConnectionManagerFactoryFn; + private readonly connectionErrorHandler: ConnectionErrorHandler; - protected constructor( - protected readonly userConfig: UserConfig, - private readonly createConnectionManager: ConnectionManagerFactoryFn, - additionalLoggers: LoggerBase[] - ) { + protected constructor({ + userConfig, + createConnectionManager, + connectionErrorHandler, + additionalLoggers = [], + }: TransportRunnerConfig) { + this.userConfig = userConfig; + this.createConnectionManager = createConnectionManager ?? createMCPConnectionManager; + this.connectionErrorHandler = connectionErrorHandler ?? defaultConnectionErrorHandler; const loggers: LoggerBase[] = [...additionalLoggers]; if (this.userConfig.loggers.includes("stderr")) { loggers.push(new ConsoleLogger()); @@ -68,6 +86,7 @@ export abstract class TransportRunnerBase { session, telemetry, userConfig: this.userConfig, + connectionErrorHandler: this.connectionErrorHandler, }); // We need to create the MCP logger after the server is constructed diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 4ed941ef9..09a7490b9 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -2,11 +2,9 @@ import { EJSON } from "bson"; import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import { type LoggerBase, LogId } from "../common/logger.js"; +import { LogId } from "../common/logger.js"; import type { Server } from "../server.js"; -import { TransportRunnerBase } from "./base.js"; -import { type UserConfig } from "../common/config.js"; -import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js"; +import { TransportRunnerBase, type TransportRunnerConfig } from "./base.js"; // This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk // but it uses EJSON.parse instead of JSON.parse to handle BSON types @@ -55,12 +53,8 @@ export function createStdioTransport(): StdioServerTransport { export class StdioRunner extends TransportRunnerBase { private server: Server | undefined; - constructor( - userConfig: UserConfig, - createConnectionManager: ConnectionManagerFactoryFn = createMCPConnectionManager, - additionalLoggers: LoggerBase[] = [] - ) { - super(userConfig, createConnectionManager, additionalLoggers); + constructor(config: TransportRunnerConfig) { + super(config); } async start(): Promise { diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index 4e8aebb8e..ad04ec732 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -3,11 +3,9 @@ import type http from "http"; import { randomUUID } from "crypto"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; -import { LogId, type LoggerBase } from "../common/logger.js"; -import { type UserConfig } from "../common/config.js"; +import { LogId } from "../common/logger.js"; import { SessionStore } from "../common/sessionStore.js"; -import { TransportRunnerBase } from "./base.js"; -import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js"; +import { TransportRunnerBase, type TransportRunnerConfig } from "./base.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001; @@ -19,12 +17,8 @@ export class StreamableHttpRunner extends TransportRunnerBase { private httpServer: http.Server | undefined; private sessionStore!: SessionStore; - constructor( - userConfig: UserConfig, - createConnectionManager: ConnectionManagerFactoryFn = createMCPConnectionManager, - additionalLoggers: LoggerBase[] = [] - ) { - super(userConfig, createConnectionManager, additionalLoggers); + constructor(config: TransportRunnerConfig) { + super(config); } public get serverAddress(): string { diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index 4ad07c83c..8820476a7 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -13,6 +13,7 @@ import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from "vitest import type { ConnectionManager, ConnectionState } from "../../src/common/connectionManager.js"; import { MCPConnectionManager } from "../../src/common/connectionManager.js"; import { DeviceId } from "../../src/helpers/deviceId.js"; +import { connectionErrorHandler } from "../../src/common/connectionErrorHandler.js"; interface ParameterInfo { name: string; @@ -101,6 +102,7 @@ export function setupIntegrationTest( name: "test-server", version: "5.2.3", }), + connectionErrorHandler, }); await mcpServer.connect(serverTransport); diff --git a/tests/integration/tools/mongodb/mongodbTool.test.ts b/tests/integration/tools/mongodb/mongodbTool.test.ts new file mode 100644 index 000000000..7f8c76268 --- /dev/null +++ b/tests/integration/tools/mongodb/mongodbTool.test.ts @@ -0,0 +1,290 @@ +import { vi, it, describe, beforeEach, afterEach, type MockedFunction, afterAll, expect } from "vitest"; +import { type CallToolResult } from "@modelcontextprotocol/sdk/types.js"; +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { MongoDBToolBase } from "../../../../src/tools/mongodb/mongodbTool.js"; +import { type OperationType } from "../../../../src/tools/tool.js"; +import { defaultDriverOptions, type UserConfig } from "../../../../src/common/config.js"; +import { MCPConnectionManager } from "../../../../src/common/connectionManager.js"; +import { Session } from "../../../../src/common/session.js"; +import { CompositeLogger } from "../../../../src/common/logger.js"; +import { DeviceId } from "../../../../src/helpers/deviceId.js"; +import { ExportsManager } from "../../../../src/common/exportsManager.js"; +import { InMemoryTransport } from "../../inMemoryTransport.js"; +import { Telemetry } from "../../../../src/telemetry/telemetry.js"; +import { Server } from "../../../../src/server.js"; +import { type ConnectionErrorHandler, connectionErrorHandler } from "../../../../src/common/connectionErrorHandler.js"; +import { defaultTestConfig } from "../../helpers.js"; +import { setupMongoDBIntegrationTest } from "./mongodbHelpers.js"; +import { ErrorCodes } from "../../../../src/common/errors.js"; + +const injectedErrorHandler: ConnectionErrorHandler = (error) => { + switch (error.code) { + case ErrorCodes.NotConnectedToMongoDB: + return { + errorHandled: true, + result: { + isError: true, + content: [ + { + type: "text", + text: "Custom handler - Not connected", + }, + ], + }, + }; + case ErrorCodes.MisconfiguredConnectionString: + return { + errorHandled: true, + result: { + isError: true, + content: [ + { + type: "text", + text: "Custom handler - Misconfigured", + }, + ], + }, + }; + } +}; + +describe("MongoDBTool implementations", () => { + const mdbIntegration = setupMongoDBIntegrationTest({ enterprise: false }, []); + const executeStub: MockedFunction<() => Promise> = vi + .fn() + .mockResolvedValue({ content: [{ type: "text", text: "Something" }] }); + class RandomTool extends MongoDBToolBase { + name = "Random"; + operationType: OperationType = "read"; + protected description = "This is a tool."; + protected argsShape = {}; + public async execute(): Promise { + await this.ensureConnected(); + return executeStub(); + } + } + + let tool: RandomTool | undefined; + let mcpClient: Client | undefined; + let mcpServer: Server | undefined; + let deviceId: DeviceId | undefined; + + async function cleanupAndStartServer( + config: Partial | undefined = {}, + errorHandler: ConnectionErrorHandler | undefined = connectionErrorHandler + ): Promise { + await cleanup(); + const userConfig: UserConfig = { ...defaultTestConfig, telemetry: "disabled", ...config }; + const driverOptions = defaultDriverOptions; + const logger = new CompositeLogger(); + const exportsManager = ExportsManager.init(userConfig, logger); + deviceId = DeviceId.create(logger); + const connectionManager = new MCPConnectionManager(userConfig, driverOptions, logger, deviceId); + const session = new Session({ + apiBaseUrl: userConfig.apiBaseUrl, + apiClientId: userConfig.apiClientId, + apiClientSecret: userConfig.apiClientSecret, + logger, + exportsManager, + connectionManager, + }); + const telemetry = Telemetry.create(session, userConfig, deviceId); + + const clientTransport = new InMemoryTransport(); + const serverTransport = new InMemoryTransport(); + + await serverTransport.start(); + await clientTransport.start(); + + void clientTransport.output.pipeTo(serverTransport.input); + void serverTransport.output.pipeTo(clientTransport.input); + + mcpClient = new Client( + { + name: "test-client", + version: "1.2.3", + }, + { + capabilities: {}, + } + ); + + mcpServer = new Server({ + session, + userConfig, + telemetry, + mcpServer: new McpServer({ + name: "test-server", + version: "5.2.3", + }), + connectionErrorHandler: errorHandler, + }); + + tool = new RandomTool(session, userConfig, telemetry); + tool.register(mcpServer); + + await mcpServer.connect(serverTransport); + await mcpClient.connect(clientTransport); + } + + async function cleanup(): Promise { + await mcpServer?.session.disconnect(); + await mcpClient?.close(); + mcpClient = undefined; + + await mcpServer?.close(); + mcpServer = undefined; + + deviceId?.close(); + deviceId = undefined; + + tool = undefined; + } + + beforeEach(async () => { + await cleanupAndStartServer(); + }); + + afterEach(async () => { + vi.clearAllMocks(); + if (mcpServer) { + await mcpServer.session.disconnect(); + } + }); + + afterAll(cleanup); + + describe("when MCP is using default connection error handler", () => { + describe("and comes across a MongoDB Error - NotConnectedToMongoDB", () => { + it("should handle the error", async () => { + const toolResponse = await mcpClient?.callTool({ + name: "Random", + arguments: {}, + }); + expect(toolResponse?.isError).to.equal(true); + expect(toolResponse?.content).toEqual( + expect.arrayContaining([ + { + type: "text", + text: "You need to connect to a MongoDB instance before you can access its data.", + }, + ]) + ); + }); + }); + + describe("and comes across a MongoDB Error - MisconfiguredConnectionString", () => { + it("should handle the error", async () => { + // This is a misconfigured connection string + await cleanupAndStartServer({ connectionString: "mongodb://localhost:1234" }); + const toolResponse = await mcpClient?.callTool({ + name: "Random", + arguments: {}, + }); + expect(toolResponse?.isError).to.equal(true); + expect(toolResponse?.content).toEqual( + expect.arrayContaining([ + { + type: "text", + text: "The configured connection string is not valid. Please check the connection string and confirm it points to a valid MongoDB instance.", + }, + ]) + ); + }); + }); + + describe("and comes across any other error MongoDB Error - ForbiddenCollscan", () => { + it("should not handle the error and let the static handling take over it", async () => { + // This is a misconfigured connection string + await cleanupAndStartServer({ connectionString: mdbIntegration.connectionString(), indexCheck: true }); + const toolResponse = await mcpClient?.callTool({ + name: "find", + arguments: { + database: "db1", + collection: "coll1", + }, + }); + expect(toolResponse?.isError).to.equal(true); + expect(toolResponse?.content).toEqual( + expect.arrayContaining([ + { + type: "text", + text: "Index check failed: The find operation on \"db1.coll1\" performs a collection scan (COLLSCAN) instead of using an index. Consider adding an index for better performance. Use 'explain' tool for query plan analysis or 'collection-indexes' to view existing indexes. To disable this check, set MDB_MCP_INDEX_CHECK to false.", + }, + ]) + ); + }); + }); + }); + + describe("when MCP is using injected connection error handler", () => { + beforeEach(async () => { + await cleanupAndStartServer(defaultTestConfig, injectedErrorHandler); + }); + + describe("and comes across a MongoDB Error - NotConnectedToMongoDB", () => { + it("should handle the error", async () => { + const toolResponse = await mcpClient?.callTool({ + name: "Random", + arguments: {}, + }); + expect(toolResponse?.isError).to.equal(true); + expect(toolResponse?.content).toEqual( + expect.arrayContaining([ + { + type: "text", + text: "Custom handler - Not connected", + }, + ]) + ); + }); + }); + + describe("and comes across a MongoDB Error - MisconfiguredConnectionString", () => { + it("should handle the error", async () => { + // This is a misconfigured connection string + await cleanupAndStartServer({ connectionString: "mongodb://localhost:1234" }, injectedErrorHandler); + const toolResponse = await mcpClient?.callTool({ + name: "Random", + arguments: {}, + }); + expect(toolResponse?.isError).to.equal(true); + expect(toolResponse?.content).toEqual( + expect.arrayContaining([ + { + type: "text", + text: "Custom handler - Misconfigured", + }, + ]) + ); + }); + }); + + describe("and comes across any other error MongoDB Error - ForbiddenCollscan", () => { + it("should not handle the error and let the static handling take over it", async () => { + // This is a misconfigured connection string + await cleanupAndStartServer( + { connectionString: mdbIntegration.connectionString(), indexCheck: true }, + injectedErrorHandler + ); + const toolResponse = await mcpClient?.callTool({ + name: "find", + arguments: { + database: "db1", + collection: "coll1", + }, + }); + expect(toolResponse?.isError).to.equal(true); + expect(toolResponse?.content).toEqual( + expect.arrayContaining([ + { + type: "text", + text: "Index check failed: The find operation on \"db1.coll1\" performs a collection scan (COLLSCAN) instead of using an index. Consider adding an index for better performance. Use 'explain' tool for query plan analysis or 'collection-indexes' to view existing indexes. To disable this check, set MDB_MCP_INDEX_CHECK to false.", + }, + ]) + ); + }); + }); + }); +}); diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index 462ba9330..6a7b17bff 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -29,7 +29,7 @@ describe("StreamableHttpRunner", () => { describe(description, () => { beforeAll(async () => { config.httpHeaders = headers; - runner = new StreamableHttpRunner(config); + runner = new StreamableHttpRunner({ userConfig: config }); await runner.start(); }); @@ -110,7 +110,7 @@ describe("StreamableHttpRunner", () => { try { for (let i = 0; i < 3; i++) { config.httpPort = 0; // Use a random port for each runner - const runner = new StreamableHttpRunner(config); + const runner = new StreamableHttpRunner({ userConfig: config }); await runner.start(); runners.push(runner); } @@ -139,7 +139,11 @@ describe("StreamableHttpRunner", () => { it("can provide custom logger", async () => { const logger = new CustomLogger(); - const runner = new StreamableHttpRunner(config, createMCPConnectionManager, [logger]); + const runner = new StreamableHttpRunner({ + userConfig: config, + createConnectionManager: createMCPConnectionManager, + additionalLoggers: [logger], + }); await runner.start(); const messages = logger.messages; From b6072318058d475cb9dca943b708f14789afa74e Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Tue, 2 Sep 2025 13:23:32 +0200 Subject: [PATCH 2/5] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/common/connectionErrorHandler.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/connectionErrorHandler.ts b/src/common/connectionErrorHandler.ts index e1ddce66a..219f01d6f 100644 --- a/src/common/connectionErrorHandler.ts +++ b/src/common/connectionErrorHandler.ts @@ -14,7 +14,7 @@ export type ConnectionErrorHandler = ( export const connectionErrorHandler: ConnectionErrorHandler = (error, { availableTools, connectionState }) => { const connectTools = availableTools .filter((t) => t.operationType === "connect") - .sort((a, b) => a.category.localeCompare(b.category)); // Sort Altas tools before MongoDB tools + .sort((a, b) => a.category.localeCompare(b.category)); // Sort Atlas tools before MongoDB tools // Find the first Atlas connect tool if available and suggest to the LLM to use it. // Note: if we ever have multiple Atlas connect tools, we may want to refine this logic to select the most appropriate one. From c83c4dbfae32f5abb0b25dec5682728f323dc935 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Tue, 2 Sep 2025 13:55:48 +0200 Subject: [PATCH 3/5] chore: expose ConnectionErrorHandler type through lib --- src/lib.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib.ts b/src/lib.ts index 01dc8b887..abbadc0a9 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -11,4 +11,5 @@ export { type ConnectionStateErrored, type ConnectionManagerFactoryFn, } from "./common/connectionManager.js"; +export { type ConnectionErrorHandler } from "./common/connectionErrorHandler.js"; export { Telemetry } from "./telemetry/telemetry.js"; From 2b725e78fc6c0cd52dcd77983f386a4d5298126c Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Tue, 2 Sep 2025 21:47:05 +0200 Subject: [PATCH 4/5] chore: export types for lib --- src/common/connectionErrorHandler.ts | 11 ++++++----- src/lib.ts | 8 +++++++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/common/connectionErrorHandler.ts b/src/common/connectionErrorHandler.ts index 219f01d6f..9de63befe 100644 --- a/src/common/connectionErrorHandler.ts +++ b/src/common/connectionErrorHandler.ts @@ -5,11 +5,12 @@ import type { ToolBase } from "../tools/tool.js"; export type ConnectionErrorHandler = ( error: MongoDBError, - additionalContext: { - availableTools: ToolBase[]; - connectionState: AnyConnectionState; - } -) => { errorHandled: false } | { errorHandled: true; result: CallToolResult }; + additionalContext: ConnectionErrorHandlerContext +) => ConnectionErrorUnhandled | ConnectionErrorHandled; + +export type ConnectionErrorHandlerContext = { availableTools: ToolBase[]; connectionState: AnyConnectionState }; +export type ConnectionErrorUnhandled = { errorHandled: false }; +export type ConnectionErrorHandled = { errorHandled: true; result: CallToolResult }; export const connectionErrorHandler: ConnectionErrorHandler = (error, { availableTools, connectionState }) => { const connectTools = availableTools diff --git a/src/lib.ts b/src/lib.ts index abbadc0a9..e6388b511 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -11,5 +11,11 @@ export { type ConnectionStateErrored, type ConnectionManagerFactoryFn, } from "./common/connectionManager.js"; -export { type ConnectionErrorHandler } from "./common/connectionErrorHandler.js"; +export type { + ConnectionErrorHandler, + ConnectionErrorHandled, + ConnectionErrorUnhandled, + ConnectionErrorHandlerContext, +} from "./common/connectionErrorHandler.js"; +export { ErrorCodes } from "./common/errors.js"; export { Telemetry } from "./telemetry/telemetry.js"; From ed6adc6d6ddeed62f13d1b33b98cbb513d30a2b9 Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Wed, 3 Sep 2025 11:59:35 +0200 Subject: [PATCH 5/5] chore: PR feedback --- src/tools/mongodb/mongodbTool.ts | 4 ++-- src/transports/base.ts | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index fd1981818..ded994ab3 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -68,9 +68,9 @@ export abstract class MongoDBToolBase extends ToolBase { }); if (outcome?.errorHandled) { return outcome.result; - } else { - return super.handleError(error, args); } + + return super.handleError(error, args); } case ErrorCodes.ForbiddenCollscan: return { diff --git a/src/transports/base.ts b/src/transports/base.ts index 6fe198c21..f870514b9 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -30,13 +30,13 @@ export abstract class TransportRunnerBase { protected constructor({ userConfig, - createConnectionManager, - connectionErrorHandler, + createConnectionManager = createMCPConnectionManager, + connectionErrorHandler = defaultConnectionErrorHandler, additionalLoggers = [], }: TransportRunnerConfig) { this.userConfig = userConfig; - this.createConnectionManager = createConnectionManager ?? createMCPConnectionManager; - this.connectionErrorHandler = connectionErrorHandler ?? defaultConnectionErrorHandler; + this.createConnectionManager = createConnectionManager; + this.connectionErrorHandler = connectionErrorHandler; const loggers: LoggerBase[] = [...additionalLoggers]; if (this.userConfig.loggers.includes("stderr")) { loggers.push(new ConsoleLogger());