diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts new file mode 100644 index 00000000..db33b21b --- /dev/null +++ b/src/common/connectionManager.ts @@ -0,0 +1,192 @@ +import { ConnectOptions } from "./config.js"; +import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; +import EventEmitter from "events"; +import { setAppNameParamIfMissing } from "../helpers/connectionOptions.js"; +import { packageInfo } from "./packageInfo.js"; +import ConnectionString from "mongodb-connection-string-url"; +import { MongoClientOptions } from "mongodb"; +import { ErrorCodes, MongoDBError } from "./errors.js"; + +export interface AtlasClusterConnectionInfo { + username: string; + projectId: string; + clusterName: string; + expiryDate: Date; +} + +export interface ConnectionSettings extends ConnectOptions { + connectionString: string; + atlas?: AtlasClusterConnectionInfo; +} + +type ConnectionTag = "connected" | "connecting" | "disconnected" | "errored"; +type OIDCConnectionAuthType = "oidc-auth-flow" | "oidc-device-flow"; +export type ConnectionStringAuthType = "scram" | "ldap" | "kerberos" | OIDCConnectionAuthType | "x.509"; + +export interface ConnectionState { + tag: ConnectionTag; + connectionStringAuthType?: ConnectionStringAuthType; + connectedAtlasCluster?: AtlasClusterConnectionInfo; +} + +export interface ConnectionStateConnected extends ConnectionState { + tag: "connected"; + serviceProvider: NodeDriverServiceProvider; +} + +export interface ConnectionStateConnecting extends ConnectionState { + tag: "connecting"; + serviceProvider: NodeDriverServiceProvider; + oidcConnectionType: OIDCConnectionAuthType; + oidcLoginUrl?: string; + oidcUserCode?: string; +} + +export interface ConnectionStateDisconnected extends ConnectionState { + tag: "disconnected"; +} + +export interface ConnectionStateErrored extends ConnectionState { + tag: "errored"; + errorReason: string; +} + +export type AnyConnectionState = + | ConnectionStateConnected + | ConnectionStateConnecting + | ConnectionStateDisconnected + | ConnectionStateErrored; + +export interface ConnectionManagerEvents { + "connection-requested": [AnyConnectionState]; + "connection-succeeded": [ConnectionStateConnected]; + "connection-timed-out": [ConnectionStateErrored]; + "connection-closed": [ConnectionStateDisconnected]; + "connection-errored": [ConnectionStateErrored]; +} + +export class ConnectionManager extends EventEmitter { + private state: AnyConnectionState; + + constructor() { + super(); + this.state = { tag: "disconnected" }; + } + + async connect(settings: ConnectionSettings): Promise { + this.emit("connection-requested", this.state); + + if (this.state.tag === "connected" || this.state.tag === "connecting") { + await this.disconnect(); + } + + let serviceProvider: NodeDriverServiceProvider; + try { + settings = { ...settings }; + settings.connectionString = setAppNameParamIfMissing({ + connectionString: settings.connectionString, + defaultAppName: `${packageInfo.mcpServerName} ${packageInfo.version}`, + }); + + serviceProvider = await NodeDriverServiceProvider.connect(settings.connectionString, { + productDocsLink: "https://github.com/mongodb-js/mongodb-mcp-server/", + productName: "MongoDB MCP", + readConcern: { + level: settings.readConcern, + }, + readPreference: settings.readPreference, + writeConcern: { + w: settings.writeConcern, + }, + timeoutMS: settings.timeoutMS, + proxy: { useEnvironmentVariableProxies: true }, + applyProxyToOIDC: true, + }); + } catch (error: unknown) { + const errorReason = error instanceof Error ? error.message : `${error as string}`; + this.changeState("connection-errored", { + tag: "errored", + errorReason, + connectedAtlasCluster: settings.atlas, + }); + throw new MongoDBError(ErrorCodes.MisconfiguredConnectionString, errorReason); + } + + try { + await serviceProvider?.runCommand?.("admin", { hello: 1 }); + + return this.changeState("connection-succeeded", { + tag: "connected", + connectedAtlasCluster: settings.atlas, + serviceProvider, + connectionStringAuthType: ConnectionManager.inferConnectionTypeFromSettings(settings), + }); + } catch (error: unknown) { + const errorReason = error instanceof Error ? error.message : `${error as string}`; + this.changeState("connection-errored", { + tag: "errored", + errorReason, + connectedAtlasCluster: settings.atlas, + }); + throw new MongoDBError(ErrorCodes.NotConnectedToMongoDB, errorReason); + } + } + + async disconnect(): Promise { + if (this.state.tag === "disconnected" || this.state.tag === "errored") { + return this.state; + } + + if (this.state.tag === "connected" || this.state.tag === "connecting") { + try { + await this.state.serviceProvider?.close(true); + } finally { + this.changeState("connection-closed", { + tag: "disconnected", + }); + } + } + + return { tag: "disconnected" }; + } + + get currentConnectionState(): AnyConnectionState { + return this.state; + } + + changeState( + event: Event, + newState: State + ): State { + this.state = newState; + // TypeScript doesn't seem to be happy with the spread operator and generics + // eslint-disable-next-line + this.emit(event, ...([newState] as any)); + return newState; + } + + static inferConnectionTypeFromSettings(settings: ConnectionSettings): ConnectionStringAuthType { + const connString = new ConnectionString(settings.connectionString); + const searchParams = connString.typedSearchParams(); + + switch (searchParams.get("authMechanism")) { + case "MONGODB-OIDC": { + return "oidc-auth-flow"; // TODO: depending on if we don't have a --browser later it can be oidc-device-flow + } + case "MONGODB-X509": + return "x.509"; + case "GSSAPI": + return "kerberos"; + case "PLAIN": + if (searchParams.get("authSource") === "$external") { + return "ldap"; + } + return "scram"; + // default should catch also null, but eslint complains + // about it. + case null: + default: + return "scram"; + } + } +} diff --git a/src/common/session.ts b/src/common/session.ts index 2a75af33..0baccc9b 100644 --- a/src/common/session.ts +++ b/src/common/session.ts @@ -1,16 +1,21 @@ -import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { ApiClient, ApiClientCredentials } from "./atlas/apiClient.js"; import { Implementation } from "@modelcontextprotocol/sdk/types.js"; import logger, { LogId } from "./logger.js"; import EventEmitter from "events"; -import { ConnectOptions } from "./config.js"; -import { setAppNameParamIfMissing } from "../helpers/connectionOptions.js"; -import { packageInfo } from "./packageInfo.js"; +import { + AtlasClusterConnectionInfo, + ConnectionManager, + ConnectionSettings, + ConnectionStateConnected, +} from "./connectionManager.js"; +import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; +import { ErrorCodes, MongoDBError } from "./errors.js"; export interface SessionOptions { apiBaseUrl: string; apiClientId?: string; apiClientSecret?: string; + connectionManager?: ConnectionManager; } export type SessionEvents = { @@ -22,20 +27,14 @@ export type SessionEvents = { export class Session extends EventEmitter { sessionId?: string; - serviceProvider?: NodeDriverServiceProvider; + connectionManager: ConnectionManager; apiClient: ApiClient; agentRunner?: { name: string; version: string; }; - connectedAtlasCluster?: { - username: string; - projectId: string; - clusterName: string; - expiryDate: Date; - }; - constructor({ apiBaseUrl, apiClientId, apiClientSecret }: SessionOptions) { + constructor({ apiBaseUrl, apiClientId, apiClientSecret, connectionManager }: SessionOptions) { super(); const credentials: ApiClientCredentials | undefined = @@ -46,10 +45,13 @@ export class Session extends EventEmitter { } : undefined; - this.apiClient = new ApiClient({ - baseUrl: apiBaseUrl, - credentials, - }); + this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }); + + this.connectionManager = connectionManager ?? new ConnectionManager(); + this.connectionManager.on("connection-succeeded", () => this.emit("connect")); + this.connectionManager.on("connection-timed-out", (error) => this.emit("connection-error", error.errorReason)); + this.connectionManager.on("connection-closed", () => this.emit("disconnect")); + this.connectionManager.on("connection-errored", (error) => this.emit("connection-error", error.errorReason)); } setAgentRunner(agentRunner: Implementation | undefined) { @@ -62,22 +64,22 @@ export class Session extends EventEmitter { } async disconnect(): Promise { - if (this.serviceProvider) { - try { - await this.serviceProvider.close(true); - } catch (err: unknown) { - const error = err instanceof Error ? err : new Error(String(err)); - logger.error(LogId.mongodbDisconnectFailure, "Error closing service provider:", error.message); - } - this.serviceProvider = undefined; + const atlasCluster = this.connectedAtlasCluster; + + try { + await this.connectionManager.disconnect(); + } catch (err: unknown) { + const error = err instanceof Error ? err : new Error(String(err)); + logger.error(LogId.mongodbDisconnectFailure, "Error closing service provider:", error.message); } - if (this.connectedAtlasCluster?.username && this.connectedAtlasCluster?.projectId) { + + if (atlasCluster?.username && atlasCluster?.projectId) { void this.apiClient .deleteDatabaseUser({ params: { path: { - groupId: this.connectedAtlasCluster.projectId, - username: this.connectedAtlasCluster.username, + groupId: atlasCluster.projectId, + username: atlasCluster.username, databaseName: "admin", }, }, @@ -90,9 +92,7 @@ export class Session extends EventEmitter { `Error deleting previous database user: ${error.message}` ); }); - this.connectedAtlasCluster = undefined; } - this.emit("disconnect"); } async close(): Promise { @@ -101,35 +101,30 @@ export class Session extends EventEmitter { this.emit("close"); } - async connectToMongoDB(connectionString: string, connectOptions: ConnectOptions): Promise { - connectionString = setAppNameParamIfMissing({ - connectionString, - defaultAppName: `${packageInfo.mcpServerName} ${packageInfo.version}`, - }); - + async connectToMongoDB(settings: ConnectionSettings): Promise { try { - this.serviceProvider = await NodeDriverServiceProvider.connect(connectionString, { - productDocsLink: "https://github.com/mongodb-js/mongodb-mcp-server/", - productName: "MongoDB MCP", - readConcern: { - level: connectOptions.readConcern, - }, - readPreference: connectOptions.readPreference, - writeConcern: { - w: connectOptions.writeConcern, - }, - timeoutMS: connectOptions.timeoutMS, - proxy: { useEnvironmentVariableProxies: true }, - applyProxyToOIDC: true, - }); - - await this.serviceProvider?.runCommand?.("admin", { hello: 1 }); + await this.connectionManager.connect({ ...settings }); } catch (error: unknown) { - const message = error instanceof Error ? error.message : `${error as string}`; + const message = error instanceof Error ? error.message : (error as string); this.emit("connection-error", message); throw error; } + } + + get isConnectedToMongoDB(): boolean { + return this.connectionManager.currentConnectionState.tag === "connected"; + } + + get serviceProvider(): NodeDriverServiceProvider { + if (this.isConnectedToMongoDB) { + const state = this.connectionManager.currentConnectionState as ConnectionStateConnected; + return state.serviceProvider; + } + + throw new MongoDBError(ErrorCodes.NotConnectedToMongoDB, "Not connected to MongoDB"); + } - this.emit("connect"); + get connectedAtlasCluster(): AtlasClusterConnectionInfo | undefined { + return this.connectionManager.currentConnectionState.connectedAtlasCluster; } } diff --git a/src/resources/common/debug.ts b/src/resources/common/debug.ts index c8de2dd0..609b4b8e 100644 --- a/src/resources/common/debug.ts +++ b/src/resources/common/debug.ts @@ -10,10 +10,11 @@ type ConnectionStateDebuggingInformation = { export class DebugResource extends ReactiveResource( { - name: "debug-mongodb-connectivity", - uri: "debug://mongodb-connectivity", + name: "debug-mongodb", + uri: "debug://mongodb", config: { - description: "Debugging information for connectivity issues.", + description: + "Debugging information for MongoDB connectivity issues. Tracks the last connectivity error and attempt information.", }, }, { diff --git a/src/server.ts b/src/server.ts index 1eccbdcd..209bec02 100644 --- a/src/server.ts +++ b/src/server.ts @@ -38,12 +38,15 @@ export class Server { } async connect(transport: Transport): Promise { + // Resources are now reactive, so we register them ASAP so they can listen to events like + // connection events. + this.registerResources(); await this.validateConfig(); - this.mcpServer.server.registerCapabilities({ logging: {} }); + this.mcpServer.server.registerCapabilities({ logging: {}, resources: { subscribe: true, listChanged: true } }); + // TODO: Eventually we might want to make tools reactive too instead of relying on custom logic. this.registerTools(); - this.registerResources(); // This is a workaround for an issue we've seen with some models, where they'll see that everything in the `arguments` // object is optional, and then not pass it at all. However, the MCP server expects the `arguments` object to be if @@ -194,7 +197,10 @@ export class Server { if (this.userConfig.connectionString) { try { - await this.session.connectToMongoDB(this.userConfig.connectionString, this.userConfig.connectOptions); + await this.session.connectToMongoDB({ + connectionString: this.userConfig.connectionString, + ...this.userConfig.connectOptions, + }); } catch (error) { console.error( "Failed to connect to MongoDB instance using the connection string from the config: ", diff --git a/src/tools/atlas/connect/connectCluster.ts b/src/tools/atlas/connect/connectCluster.ts index a0087a0e..1af1aa3d 100644 --- a/src/tools/atlas/connect/connectCluster.ts +++ b/src/tools/atlas/connect/connectCluster.ts @@ -6,6 +6,7 @@ import { generateSecurePassword } from "../../../helpers/generatePassword.js"; import logger, { LogId } from "../../../common/logger.js"; import { inspectCluster } from "../../../common/atlas/cluster.js"; import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js"; +import { AtlasClusterConnectionInfo } from "../../../common/connectionManager.js"; const EXPIRY_MS = 1000 * 60 * 60 * 12; // 12 hours @@ -22,17 +23,18 @@ export class ConnectClusterTool extends AtlasToolBase { clusterName: z.string().describe("Atlas cluster name"), }; - private async queryConnection( + private queryConnection( projectId: string, clusterName: string - ): Promise<"connected" | "disconnected" | "connecting" | "connected-to-other-cluster" | "unknown"> { + ): "connected" | "disconnected" | "connecting" | "connected-to-other-cluster" | "unknown" { if (!this.session.connectedAtlasCluster) { - if (this.session.serviceProvider) { + if (this.session.isConnectedToMongoDB) { return "connected-to-other-cluster"; } return "disconnected"; } + const currentConectionState = this.session.connectionManager.currentConnectionState; if ( this.session.connectedAtlasCluster.projectId !== projectId || this.session.connectedAtlasCluster.clusterName !== clusterName @@ -40,28 +42,26 @@ export class ConnectClusterTool extends AtlasToolBase { return "connected-to-other-cluster"; } - if (!this.session.serviceProvider) { - return "connecting"; - } - - try { - await this.session.serviceProvider.runCommand("admin", { - ping: 1, - }); - - return "connected"; - } catch (err: unknown) { - const error = err instanceof Error ? err : new Error(String(err)); - logger.debug( - LogId.atlasConnectFailure, - "atlas-connect-cluster", - `error querying cluster: ${error.message}` - ); - return "unknown"; + switch (currentConectionState.tag) { + case "connecting": + case "disconnected": // we might still be calling Atlas APIs and not attempted yet to connect to MongoDB, but we are still "connecting" + return "connecting"; + case "connected": + return "connected"; + case "errored": + logger.debug( + LogId.atlasConnectFailure, + "atlas-connect-cluster", + `error querying cluster: ${currentConectionState.errorReason}` + ); + return "unknown"; } } - private async prepareClusterConnection(projectId: string, clusterName: string): Promise { + private async prepareClusterConnection( + projectId: string, + clusterName: string + ): Promise<{ connectionString: string; atlas: AtlasClusterConnectionInfo }> { const cluster = await inspectCluster(this.session.apiClient, projectId, clusterName); if (!cluster.connectionString) { @@ -109,7 +109,7 @@ export class ConnectClusterTool extends AtlasToolBase { }, }); - this.session.connectedAtlasCluster = { + const connectedAtlasCluster = { username, projectId, clusterName, @@ -120,10 +120,11 @@ export class ConnectClusterTool extends AtlasToolBase { cn.username = username; cn.password = password; cn.searchParams.set("authSource", "admin"); - return cn.toString(); + + return { connectionString: cn.toString(), atlas: connectedAtlasCluster }; } - private async connectToCluster(projectId: string, clusterName: string, connectionString: string): Promise { + private async connectToCluster(connectionString: string, atlas: AtlasClusterConnectionInfo): Promise { let lastError: Error | undefined = undefined; logger.debug( @@ -134,18 +135,10 @@ export class ConnectClusterTool extends AtlasToolBase { // try to connect for about 5 minutes for (let i = 0; i < 600; i++) { - if ( - !this.session.connectedAtlasCluster || - this.session.connectedAtlasCluster.projectId !== projectId || - this.session.connectedAtlasCluster.clusterName !== clusterName - ) { - throw new Error("Cluster connection aborted"); - } - try { lastError = undefined; - await this.session.connectToMongoDB(connectionString, this.config.connectOptions); + await this.session.connectToMongoDB({ connectionString, ...this.config.connectOptions, atlas }); break; } catch (err: unknown) { const error = err instanceof Error ? err : new Error(String(err)); @@ -160,12 +153,20 @@ export class ConnectClusterTool extends AtlasToolBase { await sleep(500); // wait for 500ms before retrying } + + if ( + !this.session.connectedAtlasCluster || + this.session.connectedAtlasCluster.projectId !== atlas.projectId || + this.session.connectedAtlasCluster.clusterName !== atlas.clusterName + ) { + throw new Error("Cluster connection aborted"); + } } if (lastError) { if ( - this.session.connectedAtlasCluster?.projectId === projectId && - this.session.connectedAtlasCluster?.clusterName === clusterName && + this.session.connectedAtlasCluster?.projectId === atlas.projectId && + this.session.connectedAtlasCluster?.clusterName === atlas.clusterName && this.session.connectedAtlasCluster?.username ) { void this.session.apiClient @@ -187,7 +188,6 @@ export class ConnectClusterTool extends AtlasToolBase { ); }); } - this.session.connectedAtlasCluster = undefined; throw lastError; } @@ -201,7 +201,7 @@ export class ConnectClusterTool extends AtlasToolBase { protected async execute({ projectId, clusterName }: ToolArgs): Promise { await ensureCurrentIpInAccessList(this.session.apiClient, projectId); for (let i = 0; i < 60; i++) { - const state = await this.queryConnection(projectId, clusterName); + const state = this.queryConnection(projectId, clusterName); switch (state) { case "connected": { return { @@ -221,10 +221,10 @@ export class ConnectClusterTool extends AtlasToolBase { case "disconnected": default: { await this.session.disconnect(); - const connectionString = await this.prepareClusterConnection(projectId, clusterName); + const { connectionString, atlas } = await this.prepareClusterConnection(projectId, clusterName); // try to connect for about 5 minutes asynchronously - void this.connectToCluster(projectId, clusterName, connectionString).catch((err: unknown) => { + void this.connectToCluster(connectionString, atlas).catch((err: unknown) => { const error = err instanceof Error ? err : new Error(String(err)); logger.error( LogId.atlasConnectFailure, diff --git a/src/tools/mongodb/connect/connect.ts b/src/tools/mongodb/connect/connect.ts index c2100689..1a1f8cd8 100644 --- a/src/tools/mongodb/connect/connect.ts +++ b/src/tools/mongodb/connect/connect.ts @@ -46,6 +46,10 @@ export class ConnectTool extends MongoDBToolBase { constructor(session: Session, config: UserConfig, telemetry: Telemetry) { super(session, config, telemetry); + session.on("connect", () => { + this.updateMetadata(); + }); + session.on("disconnect", () => { this.updateMetadata(); }); @@ -67,6 +71,7 @@ export class ConnectTool extends MongoDBToolBase { await this.connectToMongoDB(connectionString); this.updateMetadata(); + return { content: [{ type: "text", text: "Successfully connected to MongoDB." }], }; @@ -82,7 +87,7 @@ export class ConnectTool extends MongoDBToolBase { } private updateMetadata(): void { - if (this.config.connectionString || this.session.serviceProvider) { + if (this.session.isConnectedToMongoDB) { this.update?.({ name: connectedName, description: connectedDescription, diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index 83fc85ab..7071b818 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -16,7 +16,7 @@ export abstract class MongoDBToolBase extends ToolBase { public category: ToolCategory = "mongodb"; protected async ensureConnected(): Promise { - if (!this.session.serviceProvider) { + if (!this.session.isConnectedToMongoDB) { if (this.session.connectedAtlasCluster) { throw new MongoDBError( ErrorCodes.NotConnectedToMongoDB, @@ -38,7 +38,7 @@ export abstract class MongoDBToolBase extends ToolBase { } } - if (!this.session.serviceProvider) { + if (!this.session.isConnectedToMongoDB) { throw new MongoDBError(ErrorCodes.NotConnectedToMongoDB, "Not connected to MongoDB"); } @@ -117,7 +117,7 @@ export abstract class MongoDBToolBase extends ToolBase { } protected connectToMongoDB(connectionString: string): Promise { - return this.session.connectToMongoDB(connectionString, this.config.connectOptions); + return this.session.connectToMongoDB({ connectionString, ...this.config.connectOptions }); } protected resolveTelemetryMetadata( diff --git a/tests/integration/common/connectionManager.test.ts b/tests/integration/common/connectionManager.test.ts new file mode 100644 index 00000000..fed5ac48 --- /dev/null +++ b/tests/integration/common/connectionManager.test.ts @@ -0,0 +1,167 @@ +import { + ConnectionManager, + ConnectionManagerEvents, + ConnectionStateConnected, + ConnectionStringAuthType, +} from "../../../src/common/connectionManager.js"; +import { describeWithMongoDB } from "../tools/mongodb/mongodbHelpers.js"; +import { describe, beforeEach, expect, it, vi, afterEach } from "vitest"; +import { config } from "../../../src/common/config.js"; + +describeWithMongoDB("Connection Manager", (integration) => { + function connectionManager() { + return integration.mcpServer().session.connectionManager; + } + + afterEach(async () => { + // disconnect on purpose doesn't change the state if it was failed to avoid losing + // information in production. + await connectionManager().disconnect(); + // for testing, force disconnecting AND setting the connection to closed to reset the + // state of the connection manager + connectionManager().changeState("connection-closed", { tag: "disconnected" }); + }); + + describe("when successfully connected", () => { + type ConnectionManagerSpies = { + "connection-requested": (event: ConnectionManagerEvents["connection-requested"][0]) => void; + "connection-succeeded": (event: ConnectionManagerEvents["connection-succeeded"][0]) => void; + "connection-timed-out": (event: ConnectionManagerEvents["connection-timed-out"][0]) => void; + "connection-closed": (event: ConnectionManagerEvents["connection-closed"][0]) => void; + "connection-errored": (event: ConnectionManagerEvents["connection-errored"][0]) => void; + }; + + let connectionManagerSpies: ConnectionManagerSpies; + + beforeEach(async () => { + connectionManagerSpies = { + "connection-requested": vi.fn(), + "connection-succeeded": vi.fn(), + "connection-timed-out": vi.fn(), + "connection-closed": vi.fn(), + "connection-errored": vi.fn(), + }; + + for (const [event, spy] of Object.entries(connectionManagerSpies)) { + connectionManager().on(event as keyof ConnectionManagerEvents, spy); + } + + await connectionManager().connect({ + connectionString: integration.connectionString(), + ...integration.mcpServer().userConfig.connectOptions, + }); + }); + + it("should be marked explicitly as connected", () => { + expect(connectionManager().currentConnectionState.tag).toEqual("connected"); + }); + + it("can query mongodb successfully", async () => { + const connectionState = connectionManager().currentConnectionState as ConnectionStateConnected; + const collections = await connectionState.serviceProvider.listCollections("admin"); + expect(collections).not.toBe([]); + }); + + it("should notify that the connection was requested", () => { + expect(connectionManagerSpies["connection-requested"]).toHaveBeenCalledOnce(); + }); + + it("should notify that the connection was successful", () => { + expect(connectionManagerSpies["connection-succeeded"]).toHaveBeenCalledOnce(); + }); + + describe("when disconnects", () => { + beforeEach(async () => { + await connectionManager().disconnect(); + }); + + it("should notify that it was disconnected before connecting", () => { + expect(connectionManagerSpies["connection-closed"]).toHaveBeenCalled(); + }); + + it("should be marked explicitly as disconnected", () => { + expect(connectionManager().currentConnectionState.tag).toEqual("disconnected"); + }); + }); + + describe("when reconnects", () => { + beforeEach(async () => { + await connectionManager().connect({ + connectionString: integration.connectionString(), + ...integration.mcpServer().userConfig.connectOptions, + }); + }); + + it("should notify that it was disconnected before connecting", () => { + expect(connectionManagerSpies["connection-closed"]).toHaveBeenCalled(); + }); + + it("should notify that it was connected again", () => { + expect(connectionManagerSpies["connection-succeeded"]).toHaveBeenCalled(); + }); + + it("should be marked explicitly as connected", () => { + expect(connectionManager().currentConnectionState.tag).toEqual("connected"); + }); + }); + + describe("when fails to connect to a new cluster", () => { + beforeEach(async () => { + try { + await connectionManager().connect({ + connectionString: "mongodb://localhost:xxxxx", + ...integration.mcpServer().userConfig.connectOptions, + }); + } catch (_error: unknown) { + void _error; + } + }); + + it("should notify that it was disconnected before connecting", () => { + expect(connectionManagerSpies["connection-closed"]).toHaveBeenCalled(); + }); + + it("should notify that it failed connecting", () => { + expect(connectionManagerSpies["connection-errored"]).toHaveBeenCalled(); + }); + + it("should be marked explicitly as connected", () => { + expect(connectionManager().currentConnectionState.tag).toEqual("errored"); + }); + }); + }); + + describe("when disconnected", () => { + it("should be marked explicitly as disconnected", () => { + expect(connectionManager().currentConnectionState.tag).toEqual("disconnected"); + }); + }); +}); + +describe("Connection Manager connection type inference", () => { + const testCases = [ + { connectionString: "mongodb://localhost:27017", connectionType: "scram" }, + { connectionString: "mongodb://localhost:27017?authMechanism=MONGODB-X509", connectionType: "x.509" }, + { connectionString: "mongodb://localhost:27017?authMechanism=GSSAPI", connectionType: "kerberos" }, + { + connectionString: "mongodb://localhost:27017?authMechanism=PLAIN&authSource=$external", + connectionType: "ldap", + }, + { connectionString: "mongodb://localhost:27017?authMechanism=PLAIN", connectionType: "scram" }, + { connectionString: "mongodb://localhost:27017?authMechanism=MONGODB-OIDC", connectionType: "oidc-auth-flow" }, + ] as { + connectionString: string; + connectionType: ConnectionStringAuthType; + }[]; + + for (const { connectionString, connectionType } of testCases) { + it(`infers ${connectionType} from ${connectionString}`, () => { + const actualConnectionType = ConnectionManager.inferConnectionTypeFromSettings({ + connectionString, + ...config.connectOptions, + }); + + expect(actualConnectionType).toBe(connectionType); + }); + } +}); diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index 3a3b0525..f5a6ab7f 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -8,6 +8,7 @@ import { Session } from "../../src/common/session.js"; import { Telemetry } from "../../src/telemetry/telemetry.js"; import { config } from "../../src/common/config.js"; import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from "vitest"; +import { ConnectionManager } from "../../src/common/connectionManager.js"; interface ParameterInfo { name: string; @@ -53,10 +54,13 @@ export function setupIntegrationTest(getUserConfig: () => UserConfig): Integrati } ); + const connectionManager = new ConnectionManager(); + const session = new Session({ apiBaseUrl: userConfig.apiBaseUrl, apiClientId: userConfig.apiClientId, apiClientSecret: userConfig.apiClientSecret, + connectionManager, }); // Mock hasValidAccessToken for tests diff --git a/tests/integration/tools/mongodb/connect/connect.test.ts b/tests/integration/tools/mongodb/connect/connect.test.ts index d8be8e5a..8e9d20f3 100644 --- a/tests/integration/tools/mongodb/connect/connect.test.ts +++ b/tests/integration/tools/mongodb/connect/connect.test.ts @@ -12,8 +12,11 @@ import { beforeEach, describe, expect, it } from "vitest"; describeWithMongoDB( "SwitchConnection tool", (integration) => { - beforeEach(() => { - integration.mcpServer().userConfig.connectionString = integration.connectionString(); + beforeEach(async () => { + await integration.mcpServer().session.connectToMongoDB({ + connectionString: integration.connectionString(), + ...config.connectOptions, + }); }); validateToolMetadata( @@ -75,7 +78,7 @@ describeWithMongoDB( const content = getResponseContent(response.content); - expect(content).toContain("Error running switch-connection"); + expect(content).toContain("The configured connection string is not valid."); }); }); }, @@ -125,7 +128,7 @@ describeWithMongoDB( arguments: { connectionString: "mongodb://localhost:12345" }, }); const content = getResponseContent(response.content); - expect(content).toContain("Error running connect"); + expect(content).toContain("The configured connection string is not valid."); // Should not suggest using the config connection string (because we don't have one) expect(content).not.toContain("Your config lists a different connection string"); diff --git a/tests/integration/transports/stdio.test.ts b/tests/integration/transports/stdio.test.ts index 2bc03b5b..6b08e4e6 100644 --- a/tests/integration/transports/stdio.test.ts +++ b/tests/integration/transports/stdio.test.ts @@ -1,8 +1,9 @@ import { describe, expect, it, beforeAll, afterAll } from "vitest"; import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { describeWithMongoDB } from "../tools/mongodb/mongodbHelpers.js"; -describe("StdioRunner", () => { +describeWithMongoDB("StdioRunner", (integration) => { describe("client connects successfully", () => { let client: Client; let transport: StdioClientTransport; @@ -12,6 +13,7 @@ describe("StdioRunner", () => { args: ["dist/index.js"], env: { MDB_MCP_TRANSPORT: "stdio", + MDB_MCP_CONNECTION_STRING: integration.connectionString(), }, }); client = new Client({ diff --git a/tests/unit/common/session.test.ts b/tests/unit/common/session.test.ts index 73236c5f..f96952fe 100644 --- a/tests/unit/common/session.test.ts +++ b/tests/unit/common/session.test.ts @@ -43,7 +43,10 @@ describe("Session", () => { for (const testCase of testCases) { it(`should update connection string for ${testCase.name}`, async () => { - await session.connectToMongoDB(testCase.connectionString, config.connectOptions); + await session.connectToMongoDB({ + connectionString: testCase.connectionString, + ...config.connectOptions, + }); expect(session.serviceProvider).toBeDefined(); const connectMock = MockNodeDriverServiceProvider.connect; @@ -58,7 +61,7 @@ describe("Session", () => { } it("should configure the proxy to use environment variables", async () => { - await session.connectToMongoDB("mongodb://localhost", config.connectOptions); + await session.connectToMongoDB({ connectionString: "mongodb://localhost", ...config.connectOptions }); expect(session.serviceProvider).toBeDefined(); const connectMock = MockNodeDriverServiceProvider.connect;