Skip to content

chore: refactor connections to use the new ConnectionManager to isolate long running processes like OIDC connections MCP-81 #423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Aug 6, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3977aa7
chore: refactor to include the new ConnectionManager
kmruiz Aug 4, 2025
7a1e217
chore: add tests to connection manager
kmruiz Aug 5, 2025
ece4d6c
chore: fix typo
kmruiz Aug 5, 2025
06a7ea8
chore: js prefers strict equality
kmruiz Aug 5, 2025
806382f
chore: fix linter errors
kmruiz Aug 5, 2025
4c357f9
chore: connection requested is not necessary at the end
kmruiz Aug 5, 2025
3f52815
chore: add test to connection-requested
kmruiz Aug 5, 2025
28d8b69
chore: add tests for the actual connection status
kmruiz Aug 6, 2025
f7ec158
chore: Fix typing issues and few PR suggestions
kmruiz Aug 6, 2025
f49e2b0
chore: style changes, use a getter for isConnectedToMongoDB
kmruiz Aug 6, 2025
e03b7a6
chore: move AtlasConnectionInfo to the connection manager
kmruiz Aug 6, 2025
43b54a0
chore: fixed linting issues
kmruiz Aug 6, 2025
89ba76d
chore: emit the close event
kmruiz Aug 6, 2025
61d7b1e
chore: add resource subscriptions and improve the resource prompt
kmruiz Aug 6, 2025
a33abb3
chore: Do not use anonymous tuples
kmruiz Aug 6, 2025
cc92766
chore: change the break to a return to make it easier to follow
kmruiz Aug 6, 2025
0a77b97
chore: small refactor
kmruiz Aug 6, 2025
6b5f49b
chore: minor clean up of redundant params
kmruiz Aug 6, 2025
6cbb1ab
Merge branch 'main' into chore/mcp-81
kmruiz Aug 6, 2025
1122e09
chore: ensure that we return connecting when not connected yet
kmruiz Aug 6, 2025
da872b4
chore: allow having the atlas cluster info independently of the conne…
kmruiz Aug 6, 2025
fe16acd
chore: simplify query connection logic
kmruiz Aug 6, 2025
17494ac
chore: This is clearer on how the behavior should look like
kmruiz Aug 6, 2025
693c755
chore: finish the refactor and clean up
kmruiz Aug 6, 2025
a6efc47
chore: propagate connected atlas cluster when disconnecting
kmruiz Aug 6, 2025
71dbf56
chore: fix linter issues and some status mismatch
kmruiz Aug 6, 2025
76e8ab5
chore: clean up atlas resource handling
kmruiz Aug 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions src/common/connectionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ export interface ConnectionStateErrored extends ConnectionState {
}

export type AnyConnectionState =
| ConnectionState
| ConnectionStateConnected
| ConnectionStateConnecting
| ConnectionStateDisconnected
Expand All @@ -77,7 +76,7 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
async connect(settings: ConnectionSettings): Promise<AnyConnectionState> {
this.emit("connection-requested", this.state);

if (this.state.tag == "connected" || this.state.tag == "connecting") {
if (this.state.tag === "connected" || this.state.tag === "connecting") {
await this.disconnect();
}

Expand Down Expand Up @@ -126,18 +125,13 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
}

async disconnect(): Promise<ConnectionStateDisconnected | ConnectionStateErrored> {
if (this.state.tag == "disconnected") {
return this.state as ConnectionStateDisconnected;
if (this.state.tag === "disconnected" || this.state.tag === "errored") {
return this.state;
}

if (this.state.tag == "errored") {
return this.state as ConnectionStateErrored;
}

if (this.state.tag == "connected" || this.state.tag == "connecting") {
const state = this.state as ConnectionStateConnecting | ConnectionStateConnected;
if (this.state.tag === "connected" || this.state.tag === "connecting") {
try {
await state.serviceProvider?.close(true);
await this.state.serviceProvider?.close(true);
} finally {
this.changeState("connection-closed", { tag: "disconnected" });
}
Expand All @@ -150,9 +144,14 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
return this.state;
}

changeState<State extends AnyConnectionState>(event: keyof ConnectionManagerEvents, newState: State): State {
changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
event: Event,
newState: State
): State {
this.state = newState;
this.emit(event, newState);
// TypeScript doesn't seem to be happy with the spread operator and generics
// eslint-disable-next-line
this.emit(event, ...([newState] as any));
return newState;
}

Expand All @@ -169,7 +168,7 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
case "GSSAPI":
return "kerberos";
case "PLAIN":
if (searchParams.get("authSource") == "$external") {
if (searchParams.get("authSource") === "$external") {
return "ldap";
}
break;
Expand Down
37 changes: 22 additions & 15 deletions src/common/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Implementation } from "@modelcontextprotocol/sdk/types.js";
import logger, { LogId } from "./logger.js";
import EventEmitter from "events";
import {
AnyConnectionState,
AtlasClusterConnectionInfo,
ConnectionManager,
ConnectionSettings,
ConnectionStateConnected,
Expand Down Expand Up @@ -33,12 +33,6 @@ export class Session extends EventEmitter<SessionEvents> {
name: string;
version: string;
};
connectedAtlasCluster?: {
username: string;
projectId: string;
clusterName: string;
expiryDate: Date;
};

constructor({ apiBaseUrl, apiClientId, apiClientSecret, connectionManager }: SessionOptions) {
super();
Expand Down Expand Up @@ -70,20 +64,24 @@ export class Session extends EventEmitter<SessionEvents> {
}

async disconnect(): Promise<void> {
const currentConnection = this.connectionManager.currentConnectionState;
const atlasCluster =
currentConnection.tag === "connected" ? currentConnection.connectedAtlasCluster : undefined;

try {
await this.connectionManager.disconnect();
} catch (err: unknown) {
const error = err instanceof Error ? err : new Error(String(err));
logger.error(LogId.mongodbDisconnectFailure, "Error closing service provider:", error.message);
}

if (this.connectedAtlasCluster?.username && this.connectedAtlasCluster?.projectId) {
if (atlasCluster?.username && atlasCluster?.projectId) {
void this.apiClient
.deleteDatabaseUser({
params: {
path: {
groupId: this.connectedAtlasCluster.projectId,
username: this.connectedAtlasCluster.username,
groupId: atlasCluster.projectId,
username: atlasCluster.username,
databaseName: "admin",
},
},
Expand All @@ -96,35 +94,44 @@ export class Session extends EventEmitter<SessionEvents> {
`Error deleting previous database user: ${error.message}`
);
});
this.connectedAtlasCluster = undefined;
}
}

async close(): Promise<void> {
await this.disconnect();
await this.apiClient.close();
this.emit("close");
}

async connectToMongoDB(settings: ConnectionSettings): Promise<AnyConnectionState> {
async connectToMongoDB(settings: ConnectionSettings): Promise<void> {
try {
return await this.connectionManager.connect({ ...settings });
await this.connectionManager.connect({ ...settings });
} catch (error: unknown) {
const message = error instanceof Error ? error.message : (error as string);
this.emit("connection-error", message);
throw error;
}
}

isConnectedToMongoDB(): boolean {
get isConnectedToMongoDB(): boolean {
return this.connectionManager.currentConnectionState.tag === "connected";
}

get serviceProvider(): NodeDriverServiceProvider {
if (this.isConnectedToMongoDB()) {
if (this.isConnectedToMongoDB) {
const state = this.connectionManager.currentConnectionState as ConnectionStateConnected;
return state.serviceProvider;
}

throw new MongoDBError(ErrorCodes.NotConnectedToMongoDB, "Not connected to MongoDB");
}

get connectedAtlasCluster(): AtlasClusterConnectionInfo | undefined {
const connectionState = this.connectionManager.currentConnectionState;
if (connectionState.tag === "connected") {
return connectionState.connectedAtlasCluster;
}

return undefined;
}
}
44 changes: 27 additions & 17 deletions src/tools/atlas/connect/connectCluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { generateSecurePassword } from "../../../helpers/generatePassword.js";
import logger, { LogId } from "../../../common/logger.js";
import { inspectCluster } from "../../../common/atlas/cluster.js";
import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js";
import { AtlasClusterConnectionInfo } from "../../../common/connectionManager.js";

const EXPIRY_MS = 1000 * 60 * 60 * 12; // 12 hours

Expand All @@ -27,7 +28,7 @@ export class ConnectClusterTool extends AtlasToolBase {
clusterName: string
): Promise<"connected" | "disconnected" | "connecting" | "connected-to-other-cluster" | "unknown"> {
if (!this.session.connectedAtlasCluster) {
if (this.session.isConnectedToMongoDB()) {
if (this.session.isConnectedToMongoDB) {
return "connected-to-other-cluster";
}
return "disconnected";
Expand All @@ -40,7 +41,7 @@ export class ConnectClusterTool extends AtlasToolBase {
return "connected-to-other-cluster";
}

if (!this.session.isConnectedToMongoDB()) {
if (!this.session.isConnectedToMongoDB) {
return "connecting";
}

Expand All @@ -61,7 +62,10 @@ export class ConnectClusterTool extends AtlasToolBase {
}
}

private async prepareClusterConnection(projectId: string, clusterName: string): Promise<string> {
private async prepareClusterConnection(
projectId: string,
clusterName: string
): Promise<[string, AtlasClusterConnectionInfo]> {
const cluster = await inspectCluster(this.session.apiClient, projectId, clusterName);

if (!cluster.connectionString) {
Expand Down Expand Up @@ -109,7 +113,7 @@ export class ConnectClusterTool extends AtlasToolBase {
},
});

this.session.connectedAtlasCluster = {
const connectedAtlasCluster = {
username,
projectId,
clusterName,
Expand All @@ -120,10 +124,15 @@ export class ConnectClusterTool extends AtlasToolBase {
cn.username = username;
cn.password = password;
cn.searchParams.set("authSource", "admin");
return cn.toString();
return [cn.toString(), connectedAtlasCluster];
}

private async connectToCluster(projectId: string, clusterName: string, connectionString: string): Promise<void> {
private async connectToCluster(
projectId: string,
clusterName: string,
connectionString: string,
atlas: AtlasClusterConnectionInfo
): Promise<void> {
let lastError: Error | undefined = undefined;

logger.debug(
Expand All @@ -145,7 +154,7 @@ export class ConnectClusterTool extends AtlasToolBase {
try {
lastError = undefined;

await this.session.connectToMongoDB({ connectionString, ...this.config.connectOptions });
await this.session.connectToMongoDB({ connectionString, ...this.config.connectOptions, atlas });
break;
} catch (err: unknown) {
const error = err instanceof Error ? err : new Error(String(err));
Expand Down Expand Up @@ -187,7 +196,6 @@ export class ConnectClusterTool extends AtlasToolBase {
);
});
}
this.session.connectedAtlasCluster = undefined;
throw lastError;
}

Expand Down Expand Up @@ -221,17 +229,19 @@ export class ConnectClusterTool extends AtlasToolBase {
case "disconnected":
default: {
await this.session.disconnect();
const connectionString = await this.prepareClusterConnection(projectId, clusterName);
const [connectionString, atlas] = await this.prepareClusterConnection(projectId, clusterName);

// try to connect for about 5 minutes asynchronously
void this.connectToCluster(projectId, clusterName, connectionString).catch((err: unknown) => {
const error = err instanceof Error ? err : new Error(String(err));
logger.error(
LogId.atlasConnectFailure,
"atlas-connect-cluster",
`error connecting to cluster: ${error.message}`
);
});
void this.connectToCluster(projectId, clusterName, connectionString, atlas).catch(
(err: unknown) => {
const error = err instanceof Error ? err : new Error(String(err));
logger.error(
LogId.atlasConnectFailure,
"atlas-connect-cluster",
`error connecting to cluster: ${error.message}`
);
}
);
break;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/tools/mongodb/connect/connect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ export class ConnectTool extends MongoDBToolBase {
}

private updateMetadata(): void {
if (this.session.isConnectedToMongoDB()) {
if (this.session.isConnectedToMongoDB) {
this.update?.({
name: connectedName,
description: connectedDescription,
Expand Down
7 changes: 3 additions & 4 deletions src/tools/mongodb/mongodbTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { ErrorCodes, MongoDBError } from "../../common/errors.js";
import logger, { LogId } from "../../common/logger.js";
import { Server } from "../../server.js";
import { AnyConnectionState } from "../../common/connectionManager.js";

export const DbOperationArgs = {
database: z.string().describe("Database name"),
Expand All @@ -17,7 +16,7 @@ export abstract class MongoDBToolBase extends ToolBase {
public category: ToolCategory = "mongodb";

protected async ensureConnected(): Promise<NodeDriverServiceProvider> {
if (!this.session.isConnectedToMongoDB()) {
if (!this.session.isConnectedToMongoDB) {
if (this.session.connectedAtlasCluster) {
throw new MongoDBError(
ErrorCodes.NotConnectedToMongoDB,
Expand All @@ -39,7 +38,7 @@ export abstract class MongoDBToolBase extends ToolBase {
}
}

if (!this.session.isConnectedToMongoDB()) {
if (!this.session.isConnectedToMongoDB) {
throw new MongoDBError(ErrorCodes.NotConnectedToMongoDB, "Not connected to MongoDB");
}

Expand Down Expand Up @@ -117,7 +116,7 @@ export abstract class MongoDBToolBase extends ToolBase {
return super.handleError(error, args);
}

protected connectToMongoDB(connectionString: string): Promise<AnyConnectionState> {
protected connectToMongoDB(connectionString: string): Promise<void> {
return this.session.connectToMongoDB({ connectionString, ...this.config.connectOptions });
}

Expand Down
12 changes: 12 additions & 0 deletions tests/integration/common/connectionManager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ describeWithMongoDB("Connection Manager", (integration) => {
it("should notify that it was disconnected before connecting", () => {
expect(connectionManagerSpies["connection-closed"]).toHaveBeenCalled();
});

it("should be marked explicitly as disconnected", () => {
expect(connectionManager().currentConnectionState.tag).toEqual("disconnected");
});
});

describe("when reconnects", () => {
Expand All @@ -95,6 +99,10 @@ describeWithMongoDB("Connection Manager", (integration) => {
it("should notify that it was connected again", () => {
expect(connectionManagerSpies["connection-succeeded"]).toHaveBeenCalled();
});

it("should be marked explicitly as connected", () => {
expect(connectionManager().currentConnectionState.tag).toEqual("connected");
});
});

describe("when fails to connect to a new cluster", () => {
Expand All @@ -116,6 +124,10 @@ describeWithMongoDB("Connection Manager", (integration) => {
it("should notify that it failed connecting", () => {
expect(connectionManagerSpies["connection-errored"]).toHaveBeenCalled();
});

it("should be marked explicitly as connected", () => {
expect(connectionManager().currentConnectionState.tag).toEqual("errored");
});
});
});

Expand Down
8 changes: 4 additions & 4 deletions tests/integration/tools/mongodb/connect/connect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ import { beforeEach, describe, expect, it } from "vitest";
describeWithMongoDB(
"SwitchConnection tool",
(integration) => {
beforeEach(() => {
integration.mcpServer().userConfig.connectionString = integration.connectionString();
integration.mcpServer().session.connectionManager.changeState("connection-succeeded", {
tag: "connected",
beforeEach(async () => {
await integration.mcpServer().session.connectToMongoDB({
connectionString: integration.connectionString(),
...config.connectOptions,
});
});

Expand Down
9 changes: 1 addition & 8 deletions tests/integration/transports/stdio.test.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
import { describe, expect, beforeEach, it, beforeAll, afterAll } from "vitest";
import { describe, expect, it, beforeAll, afterAll } from "vitest";
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { describeWithMongoDB } from "../tools/mongodb/mongodbHelpers.js";

describeWithMongoDB("StdioRunner", (integration) => {
beforeEach(() => {
integration.mcpServer().userConfig.connectionString = integration.connectionString();
integration.mcpServer().session.connectionManager.changeState("connection-succeeded", {
tag: "connected",
});
});

describe("client connects successfully", () => {
let client: Client;
let transport: StdioClientTransport;
Expand Down
Loading