Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

132 changes: 82 additions & 50 deletions src/common/connectionManager.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import type { UserConfig, DriverOptions } from "./config.js";
import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import EventEmitter from "events";
import { setAppNameParamIfMissing } from "../helpers/connectionOptions.js";
import { packageInfo } from "./packageInfo.js";
import ConnectionString from "mongodb-connection-string-url";
import { EventEmitter } from "events";
import type { MongoClientOptions } from "mongodb";
import { ErrorCodes, MongoDBError } from "./errors.js";
import ConnectionString from "mongodb-connection-string-url";
import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import { type ConnectionInfo, generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser";
import type { DeviceId } from "../helpers/deviceId.js";
import type { AppNameComponents } from "../helpers/connectionOptions.js";
import type { CompositeLogger } from "./logger.js";
import { LogId } from "./logger.js";
import type { ConnectionInfo } from "@mongosh/arg-parser";
import { generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser";
import type { DriverOptions, UserConfig } from "./config.js";
import { MongoDBError, ErrorCodes } from "./errors.js";
import { type CompositeLogger, LogId } from "./logger.js";
import { packageInfo } from "./packageInfo.js";
import { type AppNameComponents, setAppNameParamIfMissing } from "../helpers/connectionOptions.js";

export interface AtlasClusterConnectionInfo {
username: string;
Expand Down Expand Up @@ -71,10 +68,56 @@ export interface ConnectionManagerEvents {
"connection-error": [ConnectionStateErrored];
}

export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
/**
* For a few tests, we need the changeState method to force a connection state
* which is we have this type to typecast the actual ConnectionManager with
* public changeState (only to make TS happy).
*/
export type TestConnectionManager = ConnectionManager & {
changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
event: Event,
newState: State
): State;
};

export abstract class ConnectionManager {
protected clientName: string;
protected readonly _events;
readonly events: Pick<EventEmitter<ConnectionManagerEvents>, "on" | "off" | "once">;
private state: AnyConnectionState;

constructor() {
this.clientName = "unknown";
this.events = this._events = new EventEmitter<ConnectionManagerEvents>();
this.state = { tag: "disconnected" };
}

get currentConnectionState(): AnyConnectionState {
return this.state;
}

protected changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
event: Event,
newState: State
): State {
this.state = newState;
// TypeScript doesn't seem to be happy with the spread operator and generics
// eslint-disable-next-line
this._events.emit(event, ...([newState] as any));
return newState;
}

setClientName(clientName: string): void {
this.clientName = clientName;
}

abstract connect(settings: ConnectionSettings): Promise<AnyConnectionState>;

abstract disconnect(): Promise<ConnectionStateDisconnected | ConnectionStateErrored>;
}

export class MCPConnectionManager extends ConnectionManager {
private deviceId: DeviceId;
private clientName: string;
private bus: EventEmitter;

constructor(
Expand All @@ -85,25 +128,17 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
bus?: EventEmitter
) {
super();

this.bus = bus ?? new EventEmitter();
this.state = { tag: "disconnected" };

this.bus.on("mongodb-oidc-plugin:auth-failed", this.onOidcAuthFailed.bind(this));
this.bus.on("mongodb-oidc-plugin:auth-succeeded", this.onOidcAuthSucceeded.bind(this));

this.deviceId = deviceId;
this.clientName = "unknown";
}

setClientName(clientName: string): void {
this.clientName = clientName;
}

async connect(settings: ConnectionSettings): Promise<AnyConnectionState> {
this.emit("connection-request", this.state);
this._events.emit("connection-request", this.currentConnectionState);

if (this.state.tag === "connected" || this.state.tag === "connecting") {
if (this.currentConnectionState.tag === "connected" || this.currentConnectionState.tag === "connecting") {
await this.disconnect();
}

Expand Down Expand Up @@ -138,7 +173,7 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
connectionInfo.driverOptions.proxy ??= { useEnvironmentVariableProxies: true };
connectionInfo.driverOptions.applyProxyToOIDC ??= true;

connectionStringAuthType = ConnectionManager.inferConnectionTypeFromSettings(
connectionStringAuthType = MCPConnectionManager.inferConnectionTypeFromSettings(
this.userConfig,
connectionInfo
);
Expand All @@ -165,7 +200,10 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
}

try {
const connectionType = ConnectionManager.inferConnectionTypeFromSettings(this.userConfig, connectionInfo);
const connectionType = MCPConnectionManager.inferConnectionTypeFromSettings(
this.userConfig,
connectionInfo
);
if (connectionType.startsWith("oidc")) {
void this.pingAndForget(serviceProvider);

Expand Down Expand Up @@ -199,13 +237,13 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
}

async disconnect(): Promise<ConnectionStateDisconnected | ConnectionStateErrored> {
if (this.state.tag === "disconnected" || this.state.tag === "errored") {
return this.state;
if (this.currentConnectionState.tag === "disconnected" || this.currentConnectionState.tag === "errored") {
return this.currentConnectionState;
}

if (this.state.tag === "connected" || this.state.tag === "connecting") {
if (this.currentConnectionState.tag === "connected" || this.currentConnectionState.tag === "connecting") {
try {
await this.state.serviceProvider?.close(true);
await this.currentConnectionState.serviceProvider?.close(true);
} finally {
this.changeState("connection-close", {
tag: "disconnected",
Expand All @@ -216,30 +254,21 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
return { tag: "disconnected" };
}

get currentConnectionState(): AnyConnectionState {
return this.state;
}

changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
event: Event,
newState: State
): State {
this.state = newState;
// TypeScript doesn't seem to be happy with the spread operator and generics
// eslint-disable-next-line
this.emit(event, ...([newState] as any));
return newState;
}

private onOidcAuthFailed(error: unknown): void {
if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) {
if (
this.currentConnectionState.tag === "connecting" &&
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
) {
void this.disconnectOnOidcError(error);
}
}

private onOidcAuthSucceeded(): void {
if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) {
this.changeState("connection-success", { ...this.state, tag: "connected" });
if (
this.currentConnectionState.tag === "connecting" &&
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
) {
this.changeState("connection-success", { ...this.currentConnectionState, tag: "connected" });
}

this.logger.info({
Expand All @@ -250,9 +279,12 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
}

private onOidcNotifyDeviceFlow(flowInfo: { verificationUrl: string; userCode: string }): void {
if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) {
if (
this.currentConnectionState.tag === "connecting" &&
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
) {
this.changeState("connection-request", {
...this.state,
...this.currentConnectionState,
tag: "connecting",
connectionStringAuthType: "oidc-device-flow",
oidcLoginUrl: flowInfo.verificationUrl,
Expand Down
8 changes: 4 additions & 4 deletions src/common/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ export class Session extends EventEmitter<SessionEvents> {
this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }, logger);
this.exportsManager = exportsManager;
this.connectionManager = connectionManager;
this.connectionManager.on("connection-success", () => this.emit("connect"));
this.connectionManager.on("connection-time-out", (error) => this.emit("connection-error", error));
this.connectionManager.on("connection-close", () => this.emit("disconnect"));
this.connectionManager.on("connection-error", (error) => this.emit("connection-error", error));
this.connectionManager.events.on("connection-success", () => this.emit("connect"));
this.connectionManager.events.on("connection-time-out", (error) => this.emit("connection-error", error));
this.connectionManager.events.on("connection-close", () => this.emit("disconnect"));
this.connectionManager.events.on("connection-error", (error) => this.emit("connection-error", error));
}

setMcpClient(mcpClient: Implementation | undefined): void {
Expand Down
9 changes: 7 additions & 2 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,22 @@ import { packageInfo } from "./common/packageInfo.js";
import { StdioRunner } from "./transports/stdio.js";
import { StreamableHttpRunner } from "./transports/streamableHttp.js";
import { systemCA } from "@mongodb-js/devtools-proxy-support";
import type { ConnectionManagerFactoryFn } from "./transports/base.js";
import { MCPConnectionManager } from "./common/connectionManager.js";

async function main(): Promise<void> {
systemCA().catch(() => undefined); // load system CA asynchronously as in mongosh

assertHelpMode();
assertVersionMode();

const createConnectionManager: ConnectionManagerFactoryFn = ({ logger, deviceId }) =>
Promise.resolve(new MCPConnectionManager(config, driverOptions, logger, deviceId));

const transportRunner =
config.transport === "stdio"
? new StdioRunner(config, driverOptions)
: new StreamableHttpRunner(config, driverOptions);
? new StdioRunner(config, createConnectionManager)
: new StreamableHttpRunner(config, createConnectionManager);
const shutdown = (): void => {
transportRunner.logger.info({
id: LogId.serverCloseRequested,
Expand Down
17 changes: 13 additions & 4 deletions src/lib.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
export { Server, type ServerOptions } from "./server.js";
export { Telemetry } from "./telemetry/telemetry.js";
export { Session, type SessionOptions } from "./common/session.js";
export { type UserConfig, defaultUserConfig } from "./common/config.js";
export { defaultUserConfig, type UserConfig } from "./common/config.js";
export { LoggerBase, CompositeLogger, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js";
export { StreamableHttpRunner } from "./transports/streamableHttp.js";
export { LoggerBase } from "./common/logger.js";
export type { LogPayload, LoggerType, LogLevel } from "./common/logger.js";
export { type ConnectionManagerFactoryFn } from "./transports/base.js";
export {
ConnectionManager,
type AnyConnectionState,
type ConnectionState,
type ConnectionStateConnected,
type ConnectionStateConnecting,
type ConnectionStateDisconnected,
type ConnectionStateErrored,
} from "./common/connectionManager.js";
export { Telemetry } from "./telemetry/telemetry.js";
15 changes: 10 additions & 5 deletions src/transports/base.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { DriverOptions, UserConfig } from "../common/config.js";
import type { UserConfig } from "../common/config.js";
import { packageInfo } from "../common/packageInfo.js";
import { Server } from "../server.js";
import { Session } from "../common/session.js";
Expand All @@ -7,16 +7,21 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import type { LoggerBase } from "../common/logger.js";
import { CompositeLogger, ConsoleLogger, DiskLogger, McpLogger } from "../common/logger.js";
import { ExportsManager } from "../common/exportsManager.js";
import { ConnectionManager } from "../common/connectionManager.js";
import type { ConnectionManager } from "../common/connectionManager.js";
import { DeviceId } from "../helpers/deviceId.js";

export type ConnectionManagerFactoryFn = (createParams: {
logger: CompositeLogger;
deviceId: DeviceId;
}) => Promise<ConnectionManager>;

export abstract class TransportRunnerBase {
public logger: LoggerBase;
public deviceId: DeviceId;

protected constructor(
protected readonly userConfig: UserConfig,
private readonly driverOptions: DriverOptions,
private readonly createConnectionManager: ConnectionManagerFactoryFn,
additionalLoggers: LoggerBase[]
) {
const loggers: LoggerBase[] = [...additionalLoggers];
Expand All @@ -38,15 +43,15 @@ export abstract class TransportRunnerBase {
this.deviceId = DeviceId.create(this.logger);
}

protected setupServer(): Server {
protected async setupServer(): Promise<Server> {
const mcpServer = new McpServer({
name: packageInfo.mcpServerName,
version: packageInfo.version,
});

const logger = new CompositeLogger(this.logger);
const exportsManager = ExportsManager.init(this.userConfig, logger);
const connectionManager = new ConnectionManager(this.userConfig, this.driverOptions, logger, this.deviceId);
const connectionManager = await this.createConnectionManager({ logger, deviceId: this.deviceId });

const session = new Session({
apiBaseUrl: this.userConfig.apiBaseUrl,
Expand Down
21 changes: 12 additions & 9 deletions src/transports/stdio.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import type { LoggerBase } from "../common/logger.js";
import { LogId } from "../common/logger.js";
import type { Server } from "../server.js";
import { TransportRunnerBase } from "./base.js";
import { EJSON } from "bson";
import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js";
import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js";
import { EJSON } from "bson";
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
import type { DriverOptions, UserConfig } from "../common/config.js";
import { type LoggerBase, LogId } from "../common/logger.js";
import type { Server } from "../server.js";
import { type ConnectionManagerFactoryFn, TransportRunnerBase } from "./base.js";
import { type UserConfig } from "../common/config.js";

// This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk
// but it uses EJSON.parse instead of JSON.parse to handle BSON types
Expand Down Expand Up @@ -55,13 +54,17 @@ export function createStdioTransport(): StdioServerTransport {
export class StdioRunner extends TransportRunnerBase {
private server: Server | undefined;

constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) {
super(userConfig, driverOptions, additionalLoggers);
constructor(
userConfig: UserConfig,
createConnectionManager: ConnectionManagerFactoryFn,
additionalLoggers: LoggerBase[] = []
) {
super(userConfig, createConnectionManager, additionalLoggers);
}

async start(): Promise<void> {
try {
this.server = this.setupServer();
this.server = await this.setupServer();

const transport = createStdioTransport();

Expand Down
Loading
Loading