Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions src/common/connectionErrorHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This entire file is a cut of existing logic in MongodbToolBase.

import { ErrorCodes, type MongoDBError } from "./errors.js";
import type { AnyConnectionState } from "./connectionManager.js";
import type { ToolBase } from "../tools/tool.js";

export type ConnectionErrorHandler = (
error: MongoDBError<ErrorCodes.NotConnectedToMongoDB | ErrorCodes.MisconfiguredConnectionString>,
additionalContext: ConnectionErrorHandlerContext
) => ConnectionErrorUnhandled | ConnectionErrorHandled;

export type ConnectionErrorHandlerContext = { availableTools: ToolBase[]; connectionState: AnyConnectionState };
export type ConnectionErrorUnhandled = { errorHandled: false };
export type ConnectionErrorHandled = { errorHandled: true; result: CallToolResult };

export const connectionErrorHandler: ConnectionErrorHandler = (error, { availableTools, connectionState }) => {
const connectTools = availableTools
.filter((t) => t.operationType === "connect")
.sort((a, b) => a.category.localeCompare(b.category)); // Sort Atlas tools before MongoDB tools

// Find the first Atlas connect tool if available and suggest to the LLM to use it.
// Note: if we ever have multiple Atlas connect tools, we may want to refine this logic to select the most appropriate one.
const atlasConnectTool = connectTools?.find((t) => t.category === "atlas");
const llmConnectHint = atlasConnectTool
? `Note to LLM: prefer using the "${atlasConnectTool.name}" tool to connect to an Atlas cluster over using a connection string. Make sure to ask the user to specify a cluster name they want to connect to or ask them if they want to use the "list-clusters" tool to list all their clusters. Do not invent cluster names or connection strings unless the user has explicitly specified them. If they've previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same cluster/connection.`
: "Note to LLM: do not invent connection strings and explicitly ask the user to provide one. If they have previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same connection string.";

const connectToolsNames = connectTools?.map((t) => `"${t.name}"`).join(", ");
const additionalPromptForConnectivity: { type: "text"; text: string }[] = [];

if (connectionState.tag === "connecting" && connectionState.oidcConnectionType) {
additionalPromptForConnectivity.push({
type: "text",
text: `The user needs to finish their OIDC connection by opening '${connectionState.oidcLoginUrl}' in the browser and use the following user code: '${connectionState.oidcUserCode}'`,
});
} else {
additionalPromptForConnectivity.push({
type: "text",
text: connectToolsNames
? `Please use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance or update the MCP server configuration to include a connection string. ${llmConnectHint}`
: "There are no tools available to connect. Please update the configuration to include a connection string and restart the server.",
});
}

switch (error.code) {
case ErrorCodes.NotConnectedToMongoDB:
return {
errorHandled: true,
result: {
content: [
{
type: "text",
text: "You need to connect to a MongoDB instance before you can access its data.",
},
...additionalPromptForConnectivity,
],
isError: true,
},
};
case ErrorCodes.MisconfiguredConnectionString:
return {
errorHandled: true,
result: {
content: [
{
type: "text",
text: "The configured connection string is not valid. Please check the connection string and confirm it points to a valid MongoDB instance.",
},
{
type: "text",
text: connectTools
? `Alternatively, you can use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance. ${llmConnectHint}`
: "Please update the configuration to use a valid connection string and restart the server.",
},
],
isError: true,
},
};

default:
return { errorHandled: false };
}
};
4 changes: 2 additions & 2 deletions src/common/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ export enum ErrorCodes {
ForbiddenCollscan = 1_000_002,
}

export class MongoDBError extends Error {
export class MongoDBError<ErrorCode extends ErrorCodes = ErrorCodes> extends Error {
constructor(
public code: ErrorCodes,
public code: ErrorCode,
message: string
) {
super(message);
Expand Down
9 changes: 8 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,14 @@ async function main(): Promise<void> {
assertHelpMode();
assertVersionMode();

const transportRunner = config.transport === "stdio" ? new StdioRunner(config) : new StreamableHttpRunner(config);
const transportRunner =
config.transport === "stdio"
? new StdioRunner({
userConfig: config,
})
: new StreamableHttpRunner({
userConfig: config,
});
const shutdown = (): void => {
transportRunner.logger.info({
id: LogId.serverCloseRequested,
Expand Down
7 changes: 7 additions & 0 deletions src/lib.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,11 @@ export {
type ConnectionStateErrored,
type ConnectionManagerFactoryFn,
} from "./common/connectionManager.js";
export type {
ConnectionErrorHandler,
ConnectionErrorHandled,
ConnectionErrorUnhandled,
ConnectionErrorHandlerContext,
} from "./common/connectionErrorHandler.js";
export { ErrorCodes } from "./common/errors.js";
export { Telemetry } from "./telemetry/telemetry.js";
6 changes: 5 additions & 1 deletion src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ import assert from "assert";
import type { ToolBase } from "./tools/tool.js";
import { validateConnectionString } from "./helpers/connectionOptions.js";
import { packageInfo } from "./common/packageInfo.js";
import { type ConnectionErrorHandler } from "./common/connectionErrorHandler.js";

export interface ServerOptions {
session: Session;
userConfig: UserConfig;
mcpServer: McpServer;
telemetry: Telemetry;
connectionErrorHandler: ConnectionErrorHandler;
}

export class Server {
Expand All @@ -35,6 +37,7 @@ export class Server {
private readonly telemetry: Telemetry;
public readonly userConfig: UserConfig;
public readonly tools: ToolBase[] = [];
public readonly connectionErrorHandler: ConnectionErrorHandler;

private _mcpLogLevel: LogLevel = "debug";

Expand All @@ -45,12 +48,13 @@ export class Server {
private readonly startTime: number;
private readonly subscriptions = new Set<string>();

constructor({ session, mcpServer, userConfig, telemetry }: ServerOptions) {
constructor({ session, mcpServer, userConfig, telemetry, connectionErrorHandler }: ServerOptions) {
this.startTime = Date.now();
this.session = session;
this.telemetry = telemetry;
this.mcpServer = mcpServer;
this.userConfig = userConfig;
this.connectionErrorHandler = connectionErrorHandler;
}

async connect(transport: Transport): Promise<void> {
Expand Down
69 changes: 14 additions & 55 deletions src/tools/mongodb/mongodbTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,63 +56,22 @@ export abstract class MongoDBToolBase extends ToolBase {
args: ToolArgs<typeof this.argsShape>
): Promise<CallToolResult> | CallToolResult {
if (error instanceof MongoDBError) {
const connectTools = this.server?.tools
.filter((t) => t.operationType === "connect")
.sort((a, b) => a.category.localeCompare(b.category)); // Sort Altas tools before MongoDB tools

// Find the first Atlas connect tool if available and suggest to the LLM to use it.
// Note: if we ever have multiple Atlas connect tools, we may want to refine this logic to select the most appropriate one.
const atlasConnectTool = connectTools?.find((t) => t.category === "atlas");
const llmConnectHint = atlasConnectTool
? `Note to LLM: prefer using the "${atlasConnectTool.name}" tool to connect to an Atlas cluster over using a connection string. Make sure to ask the user to specify a cluster name they want to connect to or ask them if they want to use the "list-clusters" tool to list all their clusters. Do not invent cluster names or connection strings unless the user has explicitly specified them. If they've previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same cluster/connection.`
: "Note to LLM: do not invent connection strings and explicitly ask the user to provide one. If they have previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same connection string.";

const connectToolsNames = connectTools?.map((t) => `"${t.name}"`).join(", ");
const connectionStatus = this.session.connectionManager.currentConnectionState;
const additionalPromptForConnectivity: { type: "text"; text: string }[] = [];

if (connectionStatus.tag === "connecting" && connectionStatus.oidcConnectionType) {
additionalPromptForConnectivity.push({
type: "text",
text: `The user needs to finish their OIDC connection by opening '${connectionStatus.oidcLoginUrl}' in the browser and use the following user code: '${connectionStatus.oidcUserCode}'`,
});
} else {
additionalPromptForConnectivity.push({
type: "text",
text: connectToolsNames
? `Please use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance or update the MCP server configuration to include a connection string. ${llmConnectHint}`
: "There are no tools available to connect. Please update the configuration to include a connection string and restart the server.",
});
}

switch (error.code) {
case ErrorCodes.NotConnectedToMongoDB:
return {
content: [
{
type: "text",
text: "You need to connect to a MongoDB instance before you can access its data.",
},
...additionalPromptForConnectivity,
],
isError: true,
};
case ErrorCodes.MisconfiguredConnectionString:
return {
content: [
{
type: "text",
text: "The configured connection string is not valid. Please check the connection string and confirm it points to a valid MongoDB instance.",
},
{
type: "text",
text: connectTools
? `Alternatively, you can use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance. ${llmConnectHint}`
: "Please update the configuration to use a valid connection string and restart the server.",
},
],
isError: true,
};
case ErrorCodes.MisconfiguredConnectionString: {
const connectionError = error as MongoDBError<
ErrorCodes.NotConnectedToMongoDB | ErrorCodes.MisconfiguredConnectionString
>;
const outcome = this.server?.connectionErrorHandler(connectionError, {
availableTools: this.server?.tools ?? [],
connectionState: this.session.connectionManager.currentConnectionState,
});
if (outcome?.errorHandled) {
return outcome.result;
} else {
return super.handleError(error, args);
}
}
case ErrorCodes.ForbiddenCollscan:
return {
content: [
Expand Down
31 changes: 25 additions & 6 deletions src/transports/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,35 @@ import type { LoggerBase } from "../common/logger.js";
import { CompositeLogger, ConsoleLogger, DiskLogger, McpLogger } from "../common/logger.js";
import { ExportsManager } from "../common/exportsManager.js";
import { DeviceId } from "../helpers/deviceId.js";
import { type ConnectionManagerFactoryFn } from "../common/connectionManager.js";
import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js";
import {
type ConnectionErrorHandler,
connectionErrorHandler as defaultConnectionErrorHandler,
} from "../common/connectionErrorHandler.js";

export type TransportRunnerConfig = {
userConfig: UserConfig;
createConnectionManager?: ConnectionManagerFactoryFn;
connectionErrorHandler?: ConnectionErrorHandler;
additionalLoggers?: LoggerBase[];
};

export abstract class TransportRunnerBase {
public logger: LoggerBase;
public deviceId: DeviceId;
protected readonly userConfig: UserConfig;
private readonly createConnectionManager: ConnectionManagerFactoryFn;
private readonly connectionErrorHandler: ConnectionErrorHandler;

protected constructor(
protected readonly userConfig: UserConfig,
private readonly createConnectionManager: ConnectionManagerFactoryFn,
additionalLoggers: LoggerBase[]
) {
protected constructor({
userConfig,
createConnectionManager,
connectionErrorHandler,
additionalLoggers = [],
}: TransportRunnerConfig) {
this.userConfig = userConfig;
this.createConnectionManager = createConnectionManager ?? createMCPConnectionManager;
this.connectionErrorHandler = connectionErrorHandler ?? defaultConnectionErrorHandler;
const loggers: LoggerBase[] = [...additionalLoggers];
if (this.userConfig.loggers.includes("stderr")) {
loggers.push(new ConsoleLogger());
Expand Down Expand Up @@ -68,6 +86,7 @@ export abstract class TransportRunnerBase {
session,
telemetry,
userConfig: this.userConfig,
connectionErrorHandler: this.connectionErrorHandler,
});

// We need to create the MCP logger after the server is constructed
Expand Down
14 changes: 4 additions & 10 deletions src/transports/stdio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ import { EJSON } from "bson";
import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js";
import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js";
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
import { type LoggerBase, LogId } from "../common/logger.js";
import { LogId } from "../common/logger.js";
import type { Server } from "../server.js";
import { TransportRunnerBase } from "./base.js";
import { type UserConfig } from "../common/config.js";
import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js";
import { TransportRunnerBase, type TransportRunnerConfig } from "./base.js";

// This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk
// but it uses EJSON.parse instead of JSON.parse to handle BSON types
Expand Down Expand Up @@ -55,12 +53,8 @@ export function createStdioTransport(): StdioServerTransport {
export class StdioRunner extends TransportRunnerBase {
private server: Server | undefined;

constructor(
userConfig: UserConfig,
createConnectionManager: ConnectionManagerFactoryFn = createMCPConnectionManager,
additionalLoggers: LoggerBase[] = []
) {
super(userConfig, createConnectionManager, additionalLoggers);
constructor(config: TransportRunnerConfig) {
super(config);
}

async start(): Promise<void> {
Expand Down
14 changes: 4 additions & 10 deletions src/transports/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ import type http from "http";
import { randomUUID } from "crypto";
import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js";
import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js";
import { LogId, type LoggerBase } from "../common/logger.js";
import { type UserConfig } from "../common/config.js";
import { LogId } from "../common/logger.js";
import { SessionStore } from "../common/sessionStore.js";
import { TransportRunnerBase } from "./base.js";
import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js";
import { TransportRunnerBase, type TransportRunnerConfig } from "./base.js";

const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000;
const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001;
Expand All @@ -19,12 +17,8 @@ export class StreamableHttpRunner extends TransportRunnerBase {
private httpServer: http.Server | undefined;
private sessionStore!: SessionStore;

constructor(
userConfig: UserConfig,
createConnectionManager: ConnectionManagerFactoryFn = createMCPConnectionManager,
additionalLoggers: LoggerBase[] = []
) {
super(userConfig, createConnectionManager, additionalLoggers);
constructor(config: TransportRunnerConfig) {
super(config);
}

public get serverAddress(): string {
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from "vitest
import type { ConnectionManager, ConnectionState } from "../../src/common/connectionManager.js";
import { MCPConnectionManager } from "../../src/common/connectionManager.js";
import { DeviceId } from "../../src/helpers/deviceId.js";
import { connectionErrorHandler } from "../../src/common/connectionErrorHandler.js";

interface ParameterInfo {
name: string;
Expand Down Expand Up @@ -101,6 +102,7 @@ export function setupIntegrationTest(
name: "test-server",
version: "5.2.3",
}),
connectionErrorHandler,
});

await mcpServer.connect(serverTransport);
Expand Down
Loading
Loading