diff --git a/eslint-rules/no-config-imports.js b/eslint-rules/no-config-imports.js index 908dd5ae..5c4efb7c 100644 --- a/eslint-rules/no-config-imports.js +++ b/eslint-rules/no-config-imports.js @@ -10,6 +10,10 @@ const allowedConfigValueImportFiles = [ "src/index.ts", // Config resource definition that works with the some config values "src/resources/common/config.ts", + // The file exports, a factory function to create MCPConnectionManager and + // it relies on driver options generator and default driver options from + // config file. + "src/common/connectionManager.ts", ]; // Ref: https://eslint.org/docs/latest/extend/custom-rules diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts index edb0b56e..78e51edb 100644 --- a/src/common/connectionManager.ts +++ b/src/common/connectionManager.ts @@ -1,17 +1,14 @@ -import type { UserConfig, DriverOptions } from "./config.js"; -import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; -import EventEmitter from "events"; -import { setAppNameParamIfMissing } from "../helpers/connectionOptions.js"; -import { packageInfo } from "./packageInfo.js"; -import ConnectionString from "mongodb-connection-string-url"; +import { EventEmitter } from "events"; import type { MongoClientOptions } from "mongodb"; -import { ErrorCodes, MongoDBError } from "./errors.js"; +import ConnectionString from "mongodb-connection-string-url"; +import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; +import { type ConnectionInfo, generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser"; import type { DeviceId } from "../helpers/deviceId.js"; -import type { AppNameComponents } from "../helpers/connectionOptions.js"; -import type { CompositeLogger } from "./logger.js"; -import { LogId } from "./logger.js"; -import type { ConnectionInfo } from "@mongosh/arg-parser"; -import { generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser"; +import { defaultDriverOptions, setupDriverConfig, type DriverOptions, type UserConfig } from "./config.js"; +import { MongoDBError, ErrorCodes } from "./errors.js"; +import { type LoggerBase, LogId } from "./logger.js"; +import { packageInfo } from "./packageInfo.js"; +import { type AppNameComponents, setAppNameParamIfMissing } from "../helpers/connectionOptions.js"; export interface AtlasClusterConnectionInfo { username: string; @@ -71,39 +68,76 @@ export interface ConnectionManagerEvents { "connection-error": [ConnectionStateErrored]; } -export class ConnectionManager extends EventEmitter { +/** + * For a few tests, we need the changeState method to force a connection state + * which is we have this type to typecast the actual ConnectionManager with + * public changeState (only to make TS happy). + */ +export type TestConnectionManager = ConnectionManager & { + changeState( + event: Event, + newState: State + ): State; +}; + +export abstract class ConnectionManager { + protected clientName: string; + protected readonly _events; + readonly events: Pick, "on" | "off" | "once">; private state: AnyConnectionState; + + constructor() { + this.clientName = "unknown"; + this.events = this._events = new EventEmitter(); + this.state = { tag: "disconnected" }; + } + + get currentConnectionState(): AnyConnectionState { + return this.state; + } + + protected changeState( + event: Event, + newState: State + ): State { + this.state = newState; + // TypeScript doesn't seem to be happy with the spread operator and generics + // eslint-disable-next-line + this._events.emit(event, ...([newState] as any)); + return newState; + } + + setClientName(clientName: string): void { + this.clientName = clientName; + } + + abstract connect(settings: ConnectionSettings): Promise; + + abstract disconnect(): Promise; +} + +export class MCPConnectionManager extends ConnectionManager { private deviceId: DeviceId; - private clientName: string; private bus: EventEmitter; constructor( private userConfig: UserConfig, private driverOptions: DriverOptions, - private logger: CompositeLogger, + private logger: LoggerBase, deviceId: DeviceId, bus?: EventEmitter ) { super(); - this.bus = bus ?? new EventEmitter(); - this.state = { tag: "disconnected" }; - this.bus.on("mongodb-oidc-plugin:auth-failed", this.onOidcAuthFailed.bind(this)); this.bus.on("mongodb-oidc-plugin:auth-succeeded", this.onOidcAuthSucceeded.bind(this)); - this.deviceId = deviceId; - this.clientName = "unknown"; - } - - setClientName(clientName: string): void { - this.clientName = clientName; } async connect(settings: ConnectionSettings): Promise { - this.emit("connection-request", this.state); + this._events.emit("connection-request", this.currentConnectionState); - if (this.state.tag === "connected" || this.state.tag === "connecting") { + if (this.currentConnectionState.tag === "connected" || this.currentConnectionState.tag === "connecting") { await this.disconnect(); } @@ -138,7 +172,7 @@ export class ConnectionManager extends EventEmitter { connectionInfo.driverOptions.proxy ??= { useEnvironmentVariableProxies: true }; connectionInfo.driverOptions.applyProxyToOIDC ??= true; - connectionStringAuthType = ConnectionManager.inferConnectionTypeFromSettings( + connectionStringAuthType = MCPConnectionManager.inferConnectionTypeFromSettings( this.userConfig, connectionInfo ); @@ -165,7 +199,10 @@ export class ConnectionManager extends EventEmitter { } try { - const connectionType = ConnectionManager.inferConnectionTypeFromSettings(this.userConfig, connectionInfo); + const connectionType = MCPConnectionManager.inferConnectionTypeFromSettings( + this.userConfig, + connectionInfo + ); if (connectionType.startsWith("oidc")) { void this.pingAndForget(serviceProvider); @@ -199,13 +236,13 @@ export class ConnectionManager extends EventEmitter { } async disconnect(): Promise { - if (this.state.tag === "disconnected" || this.state.tag === "errored") { - return this.state; + if (this.currentConnectionState.tag === "disconnected" || this.currentConnectionState.tag === "errored") { + return this.currentConnectionState; } - if (this.state.tag === "connected" || this.state.tag === "connecting") { + if (this.currentConnectionState.tag === "connected" || this.currentConnectionState.tag === "connecting") { try { - await this.state.serviceProvider?.close(true); + await this.currentConnectionState.serviceProvider?.close(true); } finally { this.changeState("connection-close", { tag: "disconnected", @@ -216,30 +253,21 @@ export class ConnectionManager extends EventEmitter { return { tag: "disconnected" }; } - get currentConnectionState(): AnyConnectionState { - return this.state; - } - - changeState( - event: Event, - newState: State - ): State { - this.state = newState; - // TypeScript doesn't seem to be happy with the spread operator and generics - // eslint-disable-next-line - this.emit(event, ...([newState] as any)); - return newState; - } - private onOidcAuthFailed(error: unknown): void { - if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { + if ( + this.currentConnectionState.tag === "connecting" && + this.currentConnectionState.connectionStringAuthType?.startsWith("oidc") + ) { void this.disconnectOnOidcError(error); } } private onOidcAuthSucceeded(): void { - if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { - this.changeState("connection-success", { ...this.state, tag: "connected" }); + if ( + this.currentConnectionState.tag === "connecting" && + this.currentConnectionState.connectionStringAuthType?.startsWith("oidc") + ) { + this.changeState("connection-success", { ...this.currentConnectionState, tag: "connected" }); } this.logger.info({ @@ -250,9 +278,12 @@ export class ConnectionManager extends EventEmitter { } private onOidcNotifyDeviceFlow(flowInfo: { verificationUrl: string; userCode: string }): void { - if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) { + if ( + this.currentConnectionState.tag === "connecting" && + this.currentConnectionState.connectionStringAuthType?.startsWith("oidc") + ) { this.changeState("connection-request", { - ...this.state, + ...this.currentConnectionState, tag: "connecting", connectionStringAuthType: "oidc-device-flow", oidcLoginUrl: flowInfo.verificationUrl, @@ -329,3 +360,23 @@ export class ConnectionManager extends EventEmitter { } } } + +/** + * Consumers of MCP server library have option to bring their own connection + * management if they need to. To support that, we enable injecting connection + * manager implementation through a factory function. + */ +export type ConnectionManagerFactoryFn = (createParams: { + logger: LoggerBase; + deviceId: DeviceId; + userConfig: UserConfig; +}) => Promise; + +export const createMCPConnectionManager: ConnectionManagerFactoryFn = ({ logger, deviceId, userConfig }) => { + const driverOptions = setupDriverConfig({ + config: userConfig, + defaults: defaultDriverOptions, + }); + + return Promise.resolve(new MCPConnectionManager(userConfig, driverOptions, logger, deviceId)); +}; diff --git a/src/common/session.ts b/src/common/session.ts index 87113894..c1f7b5a1 100644 --- a/src/common/session.ts +++ b/src/common/session.ts @@ -67,10 +67,10 @@ export class Session extends EventEmitter { this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }, logger); this.exportsManager = exportsManager; this.connectionManager = connectionManager; - this.connectionManager.on("connection-success", () => this.emit("connect")); - this.connectionManager.on("connection-time-out", (error) => this.emit("connection-error", error)); - this.connectionManager.on("connection-close", () => this.emit("disconnect")); - this.connectionManager.on("connection-error", (error) => this.emit("connection-error", error)); + this.connectionManager.events.on("connection-success", () => this.emit("connect")); + this.connectionManager.events.on("connection-time-out", (error) => this.emit("connection-error", error)); + this.connectionManager.events.on("connection-close", () => this.emit("disconnect")); + this.connectionManager.events.on("connection-error", (error) => this.emit("connection-error", error)); } setMcpClient(mcpClient: Implementation | undefined): void { diff --git a/src/index.ts b/src/index.ts index b1ac4b48..6a7150e3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -36,7 +36,7 @@ function enableFipsIfRequested(): void { enableFipsIfRequested(); import { ConsoleLogger, LogId } from "./common/logger.js"; -import { config, driverOptions } from "./common/config.js"; +import { config } from "./common/config.js"; import crypto from "crypto"; import { packageInfo } from "./common/packageInfo.js"; import { StdioRunner } from "./transports/stdio.js"; @@ -49,10 +49,7 @@ async function main(): Promise { assertHelpMode(); assertVersionMode(); - const transportRunner = - config.transport === "stdio" - ? new StdioRunner(config, driverOptions) - : new StreamableHttpRunner(config, driverOptions); + const transportRunner = config.transport === "stdio" ? new StdioRunner(config) : new StreamableHttpRunner(config); const shutdown = (): void => { transportRunner.logger.info({ id: LogId.serverCloseRequested, diff --git a/src/lib.ts b/src/lib.ts index 9fd921e4..01dc8b88 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -1,7 +1,14 @@ export { Server, type ServerOptions } from "./server.js"; -export { Telemetry } from "./telemetry/telemetry.js"; export { Session, type SessionOptions } from "./common/session.js"; -export { type UserConfig, defaultUserConfig } from "./common/config.js"; +export { defaultUserConfig, type UserConfig } from "./common/config.js"; +export { LoggerBase, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js"; export { StreamableHttpRunner } from "./transports/streamableHttp.js"; -export { LoggerBase } from "./common/logger.js"; -export type { LogPayload, LoggerType, LogLevel } from "./common/logger.js"; +export { + ConnectionManager, + type AnyConnectionState, + type ConnectionState, + type ConnectionStateDisconnected, + type ConnectionStateErrored, + type ConnectionManagerFactoryFn, +} from "./common/connectionManager.js"; +export { Telemetry } from "./telemetry/telemetry.js"; diff --git a/src/transports/base.ts b/src/transports/base.ts index 485752e7..d6fc53ad 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -1,4 +1,4 @@ -import type { DriverOptions, UserConfig } from "../common/config.js"; +import type { UserConfig } from "../common/config.js"; import { packageInfo } from "../common/packageInfo.js"; import { Server } from "../server.js"; import { Session } from "../common/session.js"; @@ -7,8 +7,8 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import type { LoggerBase } from "../common/logger.js"; import { CompositeLogger, ConsoleLogger, DiskLogger, McpLogger } from "../common/logger.js"; import { ExportsManager } from "../common/exportsManager.js"; -import { ConnectionManager } from "../common/connectionManager.js"; import { DeviceId } from "../helpers/deviceId.js"; +import { type ConnectionManagerFactoryFn } from "../common/connectionManager.js"; export abstract class TransportRunnerBase { public logger: LoggerBase; @@ -16,7 +16,7 @@ export abstract class TransportRunnerBase { protected constructor( protected readonly userConfig: UserConfig, - private readonly driverOptions: DriverOptions, + private readonly createConnectionManager: ConnectionManagerFactoryFn, additionalLoggers: LoggerBase[] ) { const loggers: LoggerBase[] = [...additionalLoggers]; @@ -38,7 +38,7 @@ export abstract class TransportRunnerBase { this.deviceId = DeviceId.create(this.logger); } - protected setupServer(): Server { + protected async setupServer(): Promise { const mcpServer = new McpServer({ name: packageInfo.mcpServerName, version: packageInfo.version, @@ -46,7 +46,11 @@ export abstract class TransportRunnerBase { const logger = new CompositeLogger(this.logger); const exportsManager = ExportsManager.init(this.userConfig, logger); - const connectionManager = new ConnectionManager(this.userConfig, this.driverOptions, logger, this.deviceId); + const connectionManager = await this.createConnectionManager({ + logger, + userConfig: this.userConfig, + deviceId: this.deviceId, + }); const session = new Session({ apiBaseUrl: this.userConfig.apiBaseUrl, diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 0751cac7..4ed941ef 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -1,12 +1,12 @@ -import type { LoggerBase } from "../common/logger.js"; -import { LogId } from "../common/logger.js"; -import type { Server } from "../server.js"; -import { TransportRunnerBase } from "./base.js"; +import { EJSON } from "bson"; import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; -import { EJSON } from "bson"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import type { DriverOptions, UserConfig } from "../common/config.js"; +import { type LoggerBase, LogId } from "../common/logger.js"; +import type { Server } from "../server.js"; +import { TransportRunnerBase } from "./base.js"; +import { type UserConfig } from "../common/config.js"; +import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js"; // This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk // but it uses EJSON.parse instead of JSON.parse to handle BSON types @@ -55,13 +55,17 @@ export function createStdioTransport(): StdioServerTransport { export class StdioRunner extends TransportRunnerBase { private server: Server | undefined; - constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) { - super(userConfig, driverOptions, additionalLoggers); + constructor( + userConfig: UserConfig, + createConnectionManager: ConnectionManagerFactoryFn = createMCPConnectionManager, + additionalLoggers: LoggerBase[] = [] + ) { + super(userConfig, createConnectionManager, additionalLoggers); } async start(): Promise { try { - this.server = this.setupServer(); + this.server = await this.setupServer(); const transport = createStdioTransport(); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index 1718252c..4e8aebb8 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -1,13 +1,13 @@ import express from "express"; import type http from "http"; +import { randomUUID } from "crypto"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; -import { TransportRunnerBase } from "./base.js"; -import type { DriverOptions, UserConfig } from "../common/config.js"; -import type { LoggerBase } from "../common/logger.js"; -import { LogId } from "../common/logger.js"; -import { randomUUID } from "crypto"; +import { LogId, type LoggerBase } from "../common/logger.js"; +import { type UserConfig } from "../common/config.js"; import { SessionStore } from "../common/sessionStore.js"; +import { TransportRunnerBase } from "./base.js"; +import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001; @@ -19,6 +19,14 @@ export class StreamableHttpRunner extends TransportRunnerBase { private httpServer: http.Server | undefined; private sessionStore!: SessionStore; + constructor( + userConfig: UserConfig, + createConnectionManager: ConnectionManagerFactoryFn = createMCPConnectionManager, + additionalLoggers: LoggerBase[] = [] + ) { + super(userConfig, createConnectionManager, additionalLoggers); + } + public get serverAddress(): string { const result = this.httpServer?.address(); if (typeof result === "string") { @@ -31,10 +39,6 @@ export class StreamableHttpRunner extends TransportRunnerBase { throw new Error("Server is not started yet"); } - constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) { - super(userConfig, driverOptions, additionalLoggers); - } - async start(): Promise { const app = express(); this.sessionStore = new SessionStore( @@ -113,7 +117,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { return; } - const server = this.setupServer(); + const server = await this.setupServer(); let keepAliveLoop: NodeJS.Timeout; const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: (): string => randomUUID().toString(), diff --git a/tests/integration/build.test.ts b/tests/integration/build.test.ts index f5b26827..7453cb3d 100644 --- a/tests/integration/build.test.ts +++ b/tests/integration/build.test.ts @@ -41,13 +41,16 @@ describe("Build Test", () => { const esmKeys = Object.keys(esmModule).sort(); expect(cjsKeys).toEqual(esmKeys); - expect(cjsKeys).toIncludeSameMembers([ - "Server", - "Session", - "Telemetry", - "StreamableHttpRunner", - "defaultUserConfig", - "LoggerBase", - ]); + expect(cjsKeys).toEqual( + expect.arrayContaining([ + "ConnectionManager", + "LoggerBase", + "Server", + "Session", + "StreamableHttpRunner", + "Telemetry", + "defaultUserConfig", + ]) + ); }); }); diff --git a/tests/integration/common/connectionManager.oidc.test.ts b/tests/integration/common/connectionManager.oidc.test.ts index fe65e63d..d4932995 100644 --- a/tests/integration/common/connectionManager.oidc.test.ts +++ b/tests/integration/common/connectionManager.oidc.test.ts @@ -5,7 +5,11 @@ import process from "process"; import type { MongoDBIntegrationTestCase } from "../tools/mongodb/mongodbHelpers.js"; import { describeWithMongoDB, isCommunityServer, getServerVersion } from "../tools/mongodb/mongodbHelpers.js"; import { defaultTestConfig, responseAsText, timeout, waitUntil } from "../helpers.js"; -import type { ConnectionStateConnected, ConnectionStateConnecting } from "../../../src/common/connectionManager.js"; +import type { + ConnectionStateConnected, + ConnectionStateConnecting, + TestConnectionManager, +} from "../../../src/common/connectionManager.js"; import type { UserConfig } from "../../../src/common/config.js"; import { setupDriverConfig } from "../../../src/common/config.js"; import path from "path"; @@ -122,7 +126,8 @@ describe.skipIf(process.platform !== "linux")("ConnectionManager OIDC Tests", as } beforeEach(async () => { - const connectionManager = integration.mcpServer().session.connectionManager; + const connectionManager = integration.mcpServer().session + .connectionManager as TestConnectionManager; // disconnect on purpose doesn't change the state if it was failed to avoid losing // information in production. await connectionManager.disconnect(); diff --git a/tests/integration/common/connectionManager.test.ts b/tests/integration/common/connectionManager.test.ts index 5a8cb6da..9771a1ec 100644 --- a/tests/integration/common/connectionManager.test.ts +++ b/tests/integration/common/connectionManager.test.ts @@ -2,15 +2,16 @@ import type { ConnectionManagerEvents, ConnectionStateConnected, ConnectionStringAuthType, + TestConnectionManager, } from "../../../src/common/connectionManager.js"; -import { ConnectionManager } from "../../../src/common/connectionManager.js"; +import { MCPConnectionManager } from "../../../src/common/connectionManager.js"; import type { UserConfig } from "../../../src/common/config.js"; import { describeWithMongoDB } from "../tools/mongodb/mongodbHelpers.js"; import { describe, beforeEach, expect, it, vi, afterEach } from "vitest"; describeWithMongoDB("Connection Manager", (integration) => { - function connectionManager(): ConnectionManager { - return integration.mcpServer().session.connectionManager; + function connectionManager(): TestConnectionManager { + return integration.mcpServer().session.connectionManager as TestConnectionManager; } afterEach(async () => { @@ -43,7 +44,7 @@ describeWithMongoDB("Connection Manager", (integration) => { }; for (const [event, spy] of Object.entries(connectionManagerSpies)) { - connectionManager().on(event as keyof ConnectionManagerEvents, spy); + connectionManager().events.on(event as keyof ConnectionManagerEvents, spy); } await connectionManager().connect({ @@ -224,9 +225,12 @@ describe("Connection Manager connection type inference", () => { for (const { userConfig, connectionString, connectionType } of testCases) { it(`infers ${connectionType} from ${connectionString}`, () => { - const actualConnectionType = ConnectionManager.inferConnectionTypeFromSettings(userConfig as UserConfig, { - connectionString, - }); + const actualConnectionType = MCPConnectionManager.inferConnectionTypeFromSettings( + userConfig as UserConfig, + { + connectionString, + } + ); expect(actualConnectionType).toBe(connectionType); }); diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index b67fbc16..e4913f6d 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -10,8 +10,8 @@ import type { UserConfig, DriverOptions } from "../../src/common/config.js"; import { McpError, ResourceUpdatedNotificationSchema } from "@modelcontextprotocol/sdk/types.js"; import { config, driverOptions } from "../../src/common/config.js"; import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from "vitest"; -import type { ConnectionState } from "../../src/common/connectionManager.js"; -import { ConnectionManager } from "../../src/common/connectionManager.js"; +import type { ConnectionManager, ConnectionState } from "../../src/common/connectionManager.js"; +import { MCPConnectionManager } from "../../src/common/connectionManager.js"; import { DeviceId } from "../../src/helpers/deviceId.js"; interface ParameterInfo { @@ -72,7 +72,7 @@ export function setupIntegrationTest( const exportsManager = ExportsManager.init(userConfig, logger); deviceId = DeviceId.create(logger); - const connectionManager = new ConnectionManager(userConfig, driverOptions, logger, deviceId); + const connectionManager = new MCPConnectionManager(userConfig, driverOptions, logger, deviceId); const session = new Session({ apiBaseUrl: userConfig.apiBaseUrl, diff --git a/tests/integration/telemetry.test.ts b/tests/integration/telemetry.test.ts index cc51ed8b..b63a3796 100644 --- a/tests/integration/telemetry.test.ts +++ b/tests/integration/telemetry.test.ts @@ -4,7 +4,7 @@ import { config, driverOptions } from "../../src/common/config.js"; import { DeviceId } from "../../src/helpers/deviceId.js"; import { describe, expect, it } from "vitest"; import { CompositeLogger } from "../../src/common/logger.js"; -import { ConnectionManager } from "../../src/common/connectionManager.js"; +import { MCPConnectionManager } from "../../src/common/connectionManager.js"; import { ExportsManager } from "../../src/common/exportsManager.js"; describe("Telemetry", () => { @@ -19,7 +19,7 @@ describe("Telemetry", () => { apiBaseUrl: "", logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(config, driverOptions, logger, deviceId), + connectionManager: new MCPConnectionManager(config, driverOptions, logger, deviceId), }), config, deviceId diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index f45ce3cd..462ba933 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -2,9 +2,10 @@ import { StreamableHttpRunner } from "../../../src/transports/streamableHttp.js" import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; import { describe, expect, it, beforeAll, afterAll, beforeEach } from "vitest"; -import { config, driverOptions } from "../../../src/common/config.js"; +import { config } from "../../../src/common/config.js"; import type { LoggerType, LogLevel, LogPayload } from "../../../src/common/logger.js"; import { LoggerBase, LogId } from "../../../src/common/logger.js"; +import { createMCPConnectionManager } from "../../../src/common/connectionManager.js"; describe("StreamableHttpRunner", () => { let runner: StreamableHttpRunner; @@ -28,7 +29,7 @@ describe("StreamableHttpRunner", () => { describe(description, () => { beforeAll(async () => { config.httpHeaders = headers; - runner = new StreamableHttpRunner(config, driverOptions); + runner = new StreamableHttpRunner(config); await runner.start(); }); @@ -109,7 +110,7 @@ describe("StreamableHttpRunner", () => { try { for (let i = 0; i < 3; i++) { config.httpPort = 0; // Use a random port for each runner - const runner = new StreamableHttpRunner(config, driverOptions); + const runner = new StreamableHttpRunner(config); await runner.start(); runners.push(runner); } @@ -138,7 +139,7 @@ describe("StreamableHttpRunner", () => { it("can provide custom logger", async () => { const logger = new CustomLogger(); - const runner = new StreamableHttpRunner(config, driverOptions, [logger]); + const runner = new StreamableHttpRunner(config, createMCPConnectionManager, [logger]); await runner.start(); const messages = logger.messages; diff --git a/tests/unit/common/session.test.ts b/tests/unit/common/session.test.ts index 6b2b3552..53129d9e 100644 --- a/tests/unit/common/session.test.ts +++ b/tests/unit/common/session.test.ts @@ -4,7 +4,7 @@ import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver import { Session } from "../../../src/common/session.js"; import { config, driverOptions } from "../../../src/common/config.js"; import { CompositeLogger } from "../../../src/common/logger.js"; -import { ConnectionManager } from "../../../src/common/connectionManager.js"; +import { MCPConnectionManager } from "../../../src/common/connectionManager.js"; import { ExportsManager } from "../../../src/common/exportsManager.js"; import { DeviceId } from "../../../src/helpers/deviceId.js"; @@ -27,7 +27,7 @@ describe("Session", () => { apiBaseUrl: "https://api.test.com", logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(config, driverOptions, logger, mockDeviceId), + connectionManager: new MCPConnectionManager(config, driverOptions, logger, mockDeviceId), }); MockNodeDriverServiceProvider.connect = vi.fn().mockResolvedValue({} as unknown as NodeDriverServiceProvider); diff --git a/tests/unit/resources/common/debug.test.ts b/tests/unit/resources/common/debug.test.ts index 52f1dd82..d0621026 100644 --- a/tests/unit/resources/common/debug.test.ts +++ b/tests/unit/resources/common/debug.test.ts @@ -4,7 +4,7 @@ import { Session } from "../../../../src/common/session.js"; import { Telemetry } from "../../../../src/telemetry/telemetry.js"; import { config, driverOptions } from "../../../../src/common/config.js"; import { CompositeLogger } from "../../../../src/common/logger.js"; -import { ConnectionManager } from "../../../../src/common/connectionManager.js"; +import { MCPConnectionManager } from "../../../../src/common/connectionManager.js"; import { ExportsManager } from "../../../../src/common/exportsManager.js"; import { DeviceId } from "../../../../src/helpers/deviceId.js"; @@ -15,7 +15,7 @@ describe("debug resource", () => { apiBaseUrl: "", logger, exportsManager: ExportsManager.init(config, logger), - connectionManager: new ConnectionManager(config, driverOptions, logger, deviceId), + connectionManager: new MCPConnectionManager(config, driverOptions, logger, deviceId), }); const telemetry = Telemetry.create(session, { ...config, telemetry: "disabled" }, deviceId);