Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

99 changes: 59 additions & 40 deletions src/common/connectionManager.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import type { UserConfig, DriverOptions } from "./config.js";
import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import EventEmitter from "events";
import { setAppNameParamIfMissing } from "../helpers/connectionOptions.js";
import { packageInfo } from "./packageInfo.js";
import ConnectionString from "mongodb-connection-string-url";
import { EventEmitter } from "events";
import type { MongoClientOptions } from "mongodb";
import { ErrorCodes, MongoDBError } from "./errors.js";
import ConnectionString from "mongodb-connection-string-url";
import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import { type ConnectionInfo, generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser";
import type { DeviceId } from "../helpers/deviceId.js";
import type { AppNameComponents } from "../helpers/connectionOptions.js";
import type { CompositeLogger } from "./logger.js";
import { LogId } from "./logger.js";
import type { ConnectionInfo } from "@mongosh/arg-parser";
import { generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser";
import type { DriverOptions, UserConfig } from "./config.js";
import { MongoDBError, ErrorCodes } from "./errors.js";
import { type CompositeLogger, LogId } from "./logger.js";
import { packageInfo } from "./packageInfo.js";
import { type AppNameComponents, setAppNameParamIfMissing } from "../helpers/connectionOptions.js";

export interface AtlasClusterConnectionInfo {
username: string;
Expand Down Expand Up @@ -71,10 +68,52 @@ export interface ConnectionManagerEvents {
"connection-errored": [ConnectionStateErrored];
}

export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
private state: AnyConnectionState;
/**
* For a few tests, we need the changeState method to force a connection state
* which is we have this type to typecast the actual ConnectionManager with
* public changeState (only to make TS happy).
*/
export type TestConnectionManager = ConnectionManager & {
changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
event: Event,
newState: State
): State;
};

export abstract class ConnectionManager {
protected clientName: string = "unknown";

protected readonly _events = new EventEmitter<ConnectionManagerEvents>();
readonly events: Pick<EventEmitter<ConnectionManagerEvents>, "on" | "off" | "once"> = this._events;

protected state: AnyConnectionState = { tag: "disconnected" };

get currentConnectionState(): AnyConnectionState {
return this.state;
}

protected changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
event: Event,
newState: State
): State {
this.state = newState;
// TypeScript doesn't seem to be happy with the spread operator and generics
// eslint-disable-next-line
this._events.emit(event, ...([newState] as any));
return newState;
}

setClientName(clientName: string): void {
this.clientName = clientName;
}

abstract connect(settings: ConnectionSettings): Promise<AnyConnectionState>;

abstract disconnect(): Promise<ConnectionStateDisconnected | ConnectionStateErrored>;
}

export class MCPConnectionManager extends ConnectionManager {
private deviceId: DeviceId;
private clientName: string;
private bus: EventEmitter;

constructor(
Expand All @@ -85,23 +124,15 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
bus?: EventEmitter
) {
super();

this.bus = bus ?? new EventEmitter();
this.state = { tag: "disconnected" };

this.bus.on("mongodb-oidc-plugin:auth-failed", this.onOidcAuthFailed.bind(this));
this.bus.on("mongodb-oidc-plugin:auth-succeeded", this.onOidcAuthSucceeded.bind(this));

this.deviceId = deviceId;
this.clientName = "unknown";
}

setClientName(clientName: string): void {
this.clientName = clientName;
}

async connect(settings: ConnectionSettings): Promise<AnyConnectionState> {
this.emit("connection-requested", this.state);
this._events.emit("connection-requested", this.state);

if (this.state.tag === "connected" || this.state.tag === "connecting") {
await this.disconnect();
Expand Down Expand Up @@ -158,7 +189,10 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
}

try {
const connectionType = ConnectionManager.inferConnectionTypeFromSettings(this.userConfig, connectionInfo);
const connectionType = MCPConnectionManager.inferConnectionTypeFromSettings(
this.userConfig,
connectionInfo
);
if (connectionType.startsWith("oidc")) {
void this.pingAndForget(serviceProvider);

Expand Down Expand Up @@ -208,21 +242,6 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
return { tag: "disconnected" };
}

get currentConnectionState(): AnyConnectionState {
return this.state;
}

changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
event: Event,
newState: State
): State {
this.state = newState;
// TypeScript doesn't seem to be happy with the spread operator and generics
// eslint-disable-next-line
this.emit(event, ...([newState] as any));
return newState;
}

private onOidcAuthFailed(error: unknown): void {
if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) {
void this.disconnectOnOidcError(error);
Expand Down
12 changes: 8 additions & 4 deletions src/common/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,14 @@ export class Session extends EventEmitter<SessionEvents> {
this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }, logger);
this.exportsManager = exportsManager;
this.connectionManager = connectionManager;
this.connectionManager.on("connection-succeeded", () => this.emit("connect"));
this.connectionManager.on("connection-timed-out", (error) => this.emit("connection-error", error.errorReason));
this.connectionManager.on("connection-closed", () => this.emit("disconnect"));
this.connectionManager.on("connection-errored", (error) => this.emit("connection-error", error.errorReason));
this.connectionManager.events.on("connection-succeeded", () => this.emit("connect"));
this.connectionManager.events.on("connection-timed-out", (error) =>
this.emit("connection-error", error.errorReason)
);
this.connectionManager.events.on("connection-closed", () => this.emit("disconnect"));
this.connectionManager.events.on("connection-errored", (error) =>
this.emit("connection-error", error.errorReason)
);
}

setMcpClient(mcpClient: Implementation | undefined): void {
Expand Down
9 changes: 7 additions & 2 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,22 @@ import { packageInfo } from "./common/packageInfo.js";
import { StdioRunner } from "./transports/stdio.js";
import { StreamableHttpRunner } from "./transports/streamableHttp.js";
import { systemCA } from "@mongodb-js/devtools-proxy-support";
import type { ConnectionManagerFactoryFn } from "./transports/base.js";
import { MCPConnectionManager } from "./common/connectionManager.js";

async function main(): Promise<void> {
systemCA().catch(() => undefined); // load system CA asynchronously as in mongosh

assertHelpMode();
assertVersionMode();

const createConnectionManager: ConnectionManagerFactoryFn = ({ logger, deviceId }) =>
Promise.resolve(new MCPConnectionManager(config, driverOptions, logger, deviceId));

const transportRunner =
config.transport === "stdio"
? new StdioRunner(config, driverOptions)
: new StreamableHttpRunner(config, driverOptions);
? new StdioRunner(config, createConnectionManager)
: new StreamableHttpRunner(config, createConnectionManager);
const shutdown = (): void => {
transportRunner.logger.info({
id: LogId.serverCloseRequested,
Expand Down
17 changes: 13 additions & 4 deletions src/lib.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
export { Server, type ServerOptions } from "./server.js";
export { Telemetry } from "./telemetry/telemetry.js";
export { Session, type SessionOptions } from "./common/session.js";
export { type UserConfig, defaultUserConfig } from "./common/config.js";
export { defaultUserConfig, type UserConfig } from "./common/config.js";
export { LoggerBase, CompositeLogger, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js";
export { StreamableHttpRunner } from "./transports/streamableHttp.js";
export { LoggerBase } from "./common/logger.js";
export type { LogPayload, LoggerType, LogLevel } from "./common/logger.js";
export { type ConnectionManagerFactoryFn } from "./transports/base.js";
export {
ConnectionManager,
type AnyConnectionState,
type ConnectionState,
type ConnectionStateConnected,
type ConnectionStateConnecting,
type ConnectionStateDisconnected,
type ConnectionStateErrored,
} from "./common/connectionManager.js";
export { Telemetry } from "./telemetry/telemetry.js";
15 changes: 10 additions & 5 deletions src/transports/base.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { DriverOptions, UserConfig } from "../common/config.js";
import type { UserConfig } from "../common/config.js";
import { packageInfo } from "../common/packageInfo.js";
import { Server } from "../server.js";
import { Session } from "../common/session.js";
Expand All @@ -7,16 +7,21 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import type { LoggerBase } from "../common/logger.js";
import { CompositeLogger, ConsoleLogger, DiskLogger, McpLogger } from "../common/logger.js";
import { ExportsManager } from "../common/exportsManager.js";
import { ConnectionManager } from "../common/connectionManager.js";
import type { ConnectionManager } from "../common/connectionManager.js";
import { DeviceId } from "../helpers/deviceId.js";

export type ConnectionManagerFactoryFn = (createParams: {
logger: CompositeLogger;
deviceId: DeviceId;
}) => Promise<ConnectionManager>;

export abstract class TransportRunnerBase {
public logger: LoggerBase;
public deviceId: DeviceId;

protected constructor(
protected readonly userConfig: UserConfig,
private readonly driverOptions: DriverOptions,
private readonly createConnectionManager: ConnectionManagerFactoryFn,
additionalLoggers: LoggerBase[]
) {
const loggers: LoggerBase[] = [...additionalLoggers];
Expand All @@ -38,15 +43,15 @@ export abstract class TransportRunnerBase {
this.deviceId = DeviceId.create(this.logger);
}

protected setupServer(): Server {
protected async setupServer(): Promise<Server> {
const mcpServer = new McpServer({
name: packageInfo.mcpServerName,
version: packageInfo.version,
});

const logger = new CompositeLogger(this.logger);
const exportsManager = ExportsManager.init(this.userConfig, logger);
const connectionManager = new ConnectionManager(this.userConfig, this.driverOptions, logger, this.deviceId);
const connectionManager = await this.createConnectionManager({ logger, deviceId: this.deviceId });

const session = new Session({
apiBaseUrl: this.userConfig.apiBaseUrl,
Expand Down
21 changes: 12 additions & 9 deletions src/transports/stdio.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import type { LoggerBase } from "../common/logger.js";
import { LogId } from "../common/logger.js";
import type { Server } from "../server.js";
import { TransportRunnerBase } from "./base.js";
import { EJSON } from "bson";
import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js";
import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js";
import { EJSON } from "bson";
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
import type { DriverOptions, UserConfig } from "../common/config.js";
import { type LoggerBase, LogId } from "../common/logger.js";
import type { Server } from "../server.js";
import { type ConnectionManagerFactoryFn, TransportRunnerBase } from "./base.js";
import { type UserConfig } from "../common/config.js";

// This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk
// but it uses EJSON.parse instead of JSON.parse to handle BSON types
Expand Down Expand Up @@ -55,13 +54,17 @@ export function createStdioTransport(): StdioServerTransport {
export class StdioRunner extends TransportRunnerBase {
private server: Server | undefined;

constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) {
super(userConfig, driverOptions, additionalLoggers);
constructor(
userConfig: UserConfig,
createConnectionManager: ConnectionManagerFactoryFn,
additionalLoggers: LoggerBase[] = []
) {
super(userConfig, createConnectionManager, additionalLoggers);
}

async start(): Promise<void> {
try {
this.server = this.setupServer();
this.server = await this.setupServer();

const transport = createStdioTransport();

Expand Down
23 changes: 13 additions & 10 deletions src/transports/streamableHttp.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import express from "express";
import type http from "http";
import { randomUUID } from "crypto";
import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js";
import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js";
import { TransportRunnerBase } from "./base.js";
import type { DriverOptions, UserConfig } from "../common/config.js";
import type { LoggerBase } from "../common/logger.js";
import { LogId } from "../common/logger.js";
import { randomUUID } from "crypto";
import { LogId, type LoggerBase } from "../common/logger.js";
import { type UserConfig } from "../common/config.js";
import { SessionStore } from "../common/sessionStore.js";
import { type ConnectionManagerFactoryFn, TransportRunnerBase } from "./base.js";

const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000;
const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001;
Expand All @@ -19,6 +18,14 @@ export class StreamableHttpRunner extends TransportRunnerBase {
private httpServer: http.Server | undefined;
private sessionStore!: SessionStore;

constructor(
userConfig: UserConfig,
createConnectionManager: ConnectionManagerFactoryFn,
additionalLoggers: LoggerBase[] = []
) {
super(userConfig, createConnectionManager, additionalLoggers);
}

public get serverAddress(): string {
const result = this.httpServer?.address();
if (typeof result === "string") {
Expand All @@ -31,10 +38,6 @@ export class StreamableHttpRunner extends TransportRunnerBase {
throw new Error("Server is not started yet");
}

constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) {
super(userConfig, driverOptions, additionalLoggers);
}

async start(): Promise<void> {
const app = express();
this.sessionStore = new SessionStore(
Expand Down Expand Up @@ -113,7 +116,7 @@ export class StreamableHttpRunner extends TransportRunnerBase {
return;
}

const server = this.setupServer();
const server = await this.setupServer();
let keepAliveLoop: NodeJS.Timeout;
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: (): string => randomUUID().toString(),
Expand Down
20 changes: 12 additions & 8 deletions tests/integration/build.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,17 @@ describe("Build Test", () => {
const esmKeys = Object.keys(esmModule).sort();

expect(cjsKeys).toEqual(esmKeys);
expect(cjsKeys).toIncludeSameMembers([
"Server",
"Session",
"Telemetry",
"StreamableHttpRunner",
"defaultUserConfig",
"LoggerBase",
]);
expect(cjsKeys).toEqual(
expect.arrayContaining([
"CompositeLogger",
"ConnectionManager",
"LoggerBase",
"Server",
"Session",
"StreamableHttpRunner",
"Telemetry",
"defaultUserConfig",
])
);
});
});
Loading
Loading