Skip to content
54 changes: 54 additions & 0 deletions src/common/atlas/accessListUtils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { ApiClient } from "./apiClient.js";
import logger, { LogId } from "../logger.js";
import { ApiClientError } from "./apiClientError.js";

export const DEFAULT_ACCESS_LIST_COMMENT = "Added by MongoDB MCP Server to enable tool access";

export async function makeCurrentIpAccessListEntry(
apiClient: ApiClient,
projectId: string,
comment: string = DEFAULT_ACCESS_LIST_COMMENT
) {
const { currentIpv4Address } = await apiClient.getIpInfo();
return {
groupId: projectId,
ipAddress: currentIpv4Address,
comment,
};
}

/**
* Ensures the current public IP is in the access list for the given Atlas project.
* If the IP is already present, this is a no-op.
* @param apiClient The Atlas API client instance
* @param projectId The Atlas project ID
*/
export async function ensureCurrentIpInAccessList(apiClient: ApiClient, projectId: string): Promise<void> {
const entry = await makeCurrentIpAccessListEntry(apiClient, projectId, DEFAULT_ACCESS_LIST_COMMENT);
try {
await apiClient.createProjectIpAccessList({
params: { path: { groupId: projectId } },
body: [entry],
});
logger.debug(
LogId.atlasIpAccessListAdded,
"accessListUtils",
`IP access list created: ${JSON.stringify(entry)}`
);
} catch (err) {
if (err instanceof ApiClientError && err.response?.status === 409) {
// 409 Conflict: entry already exists, log info
logger.debug(
LogId.atlasIpAccessListAdded,
"accessListUtils",
`IP address ${entry.ipAddress} is already present in the access list for project ${projectId}.`
);
return;
}
logger.warning(
LogId.atlasIpAccessListAddFailure,
"accessListUtils",
`Error adding IP access list: ${err instanceof Error ? err.message : String(err)}`
);
}
}
2 changes: 2 additions & 0 deletions src/common/logger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ export const LogId = {
atlasConnectAttempt: mongoLogId(1_001_005),
atlasConnectSucceeded: mongoLogId(1_001_006),
atlasApiRevokeFailure: mongoLogId(1_001_007),
atlasIpAccessListAdded: mongoLogId(1_001_008),
atlasIpAccessListAddFailure: mongoLogId(1_001_009),

telemetryDisabled: mongoLogId(1_002_001),
telemetryEmitFailure: mongoLogId(1_002_002),
Expand Down
2 changes: 2 additions & 0 deletions src/tools/atlas/connect/connectCluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { ToolArgs, OperationType } from "../../tool.js";
import { generateSecurePassword } from "../../../helpers/generatePassword.js";
import logger, { LogId } from "../../../common/logger.js";
import { inspectCluster } from "../../../common/atlas/cluster.js";
import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js";

const EXPIRY_MS = 1000 * 60 * 60 * 12; // 12 hours

Expand Down Expand Up @@ -198,6 +199,7 @@ export class ConnectClusterTool extends AtlasToolBase {
}

protected async execute({ projectId, clusterName }: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
await ensureCurrentIpInAccessList(this.session.apiClient, projectId);
for (let i = 0; i < 60; i++) {
const state = await this.queryConnection(projectId, clusterName);
switch (state) {
Expand Down
24 changes: 13 additions & 11 deletions src/tools/atlas/create/createAccessList.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ import { z } from "zod";
import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { AtlasToolBase } from "../atlasTool.js";
import { ToolArgs, OperationType } from "../../tool.js";

const DEFAULT_COMMENT = "Added by Atlas MCP";
import { makeCurrentIpAccessListEntry, DEFAULT_ACCESS_LIST_COMMENT } from "../../../common/atlas/accessListUtils.js";

export class CreateAccessListTool extends AtlasToolBase {
public name = "atlas-create-access-list";
Expand All @@ -17,7 +16,11 @@ export class CreateAccessListTool extends AtlasToolBase {
.optional(),
cidrBlocks: z.array(z.string().cidr()).describe("CIDR blocks to allow access from").optional(),
currentIpAddress: z.boolean().describe("Add the current IP address").default(false),
comment: z.string().describe("Comment for the access list entries").default(DEFAULT_COMMENT).optional(),
comment: z
.string()
.describe("Comment for the access list entries")
.default(DEFAULT_ACCESS_LIST_COMMENT)
.optional(),
};

protected async execute({
Expand All @@ -34,23 +37,22 @@ export class CreateAccessListTool extends AtlasToolBase {
const ipInputs = (ipAddresses || []).map((ipAddress) => ({
groupId: projectId,
ipAddress,
comment: comment || DEFAULT_COMMENT,
comment: comment || DEFAULT_ACCESS_LIST_COMMENT,
}));

if (currentIpAddress) {
const currentIp = await this.session.apiClient.getIpInfo();
const input = {
groupId: projectId,
ipAddress: currentIp.currentIpv4Address,
comment: comment || DEFAULT_COMMENT,
};
const input = await makeCurrentIpAccessListEntry(
this.session.apiClient,
projectId,
comment || DEFAULT_ACCESS_LIST_COMMENT
);
ipInputs.push(input);
}

const cidrInputs = (cidrBlocks || []).map((cidrBlock) => ({
groupId: projectId,
cidrBlock,
comment: comment || DEFAULT_COMMENT,
comment: comment || DEFAULT_ACCESS_LIST_COMMENT,
}));

const inputs = [...ipInputs, ...cidrInputs];
Expand Down
2 changes: 2 additions & 0 deletions src/tools/atlas/create/createDBUser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { AtlasToolBase } from "../atlasTool.js";
import { ToolArgs, OperationType } from "../../tool.js";
import { CloudDatabaseUser, DatabaseUserRole } from "../../../common/atlas/openapi.js";
import { generateSecurePassword } from "../../../helpers/generatePassword.js";
import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js";

export class CreateDBUserTool extends AtlasToolBase {
public name = "atlas-create-db-user";
Expand Down Expand Up @@ -44,6 +45,7 @@ export class CreateDBUserTool extends AtlasToolBase {
roles,
clusters,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
await ensureCurrentIpInAccessList(this.session.apiClient, projectId);
const shouldGeneratePassword = !password;
if (shouldGeneratePassword) {
password = await generateSecurePassword();
Expand Down
2 changes: 2 additions & 0 deletions src/tools/atlas/create/createFreeCluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { AtlasToolBase } from "../atlasTool.js";
import { ToolArgs, OperationType } from "../../tool.js";
import { ClusterDescription20240805 } from "../../../common/atlas/openapi.js";
import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js";

export class CreateFreeClusterTool extends AtlasToolBase {
public name = "atlas-create-free-cluster";
Expand Down Expand Up @@ -37,6 +38,7 @@ export class CreateFreeClusterTool extends AtlasToolBase {
terminationProtectionEnabled: false,
} as unknown as ClusterDescription20240805;

await ensureCurrentIpInAccessList(this.session.apiClient, projectId);
await this.session.apiClient.createCluster({
params: {
path: {
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/tools/atlas/accessLists.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { describeWithAtlas, withProject } from "./atlasHelpers.js";
import { expectDefined } from "../../helpers.js";
import { afterAll, beforeAll, describe, expect, it } from "vitest";
import { ensureCurrentIpInAccessList } from "../../../../src/common/atlas/accessListUtils.js";

function generateRandomIp() {
const randomIp: number[] = [192];
Expand Down Expand Up @@ -95,5 +96,23 @@ describeWithAtlas("ip access lists", (integration) => {
}
});
});

describe("ensureCurrentIpInAccessList helper", () => {
it("should add the current IP to the access list and be idempotent", async () => {
const apiClient = integration.mcpServer().session.apiClient;
const projectId = getProjectId();
const ipInfo = await apiClient.getIpInfo();
// First call should add the IP
await expect(ensureCurrentIpInAccessList(apiClient, projectId)).resolves.not.toThrow();
// Second call should be a no-op (idempotent)
await expect(ensureCurrentIpInAccessList(apiClient, projectId)).resolves.not.toThrow();
// Check that the IP is present in the access list
const accessList = await apiClient.listProjectIpAccessLists({
params: { path: { groupId: projectId } },
});
const found = accessList.results?.some((entry) => entry.ipAddress === ipInfo.currentIpv4Address);
expect(found).toBe(true);
});
});
});
});
11 changes: 10 additions & 1 deletion tests/integration/tools/atlas/clusters.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,10 @@ describeWithAtlas("clusters", (integration) => {
expect(createFreeCluster.inputSchema.properties).toHaveProperty("region");
});

it("should create a free cluster", async () => {
it("should create a free cluster and add current IP to access list", async () => {
const projectId = getProjectId();
const session = integration.mcpServer().session;
const ipInfo = await session.apiClient.getIpInfo();

const response = (await integration.mcpClient().callTool({
name: "atlas-create-free-cluster",
Expand All @@ -96,6 +98,13 @@ describeWithAtlas("clusters", (integration) => {
expect(response.content).toBeInstanceOf(Array);
expect(response.content).toHaveLength(2);
expect(response.content[0]?.text).toContain("has been created");

// Check that the current IP is present in the access list
const accessList = await session.apiClient.listProjectIpAccessLists({
params: { path: { groupId: projectId } },
});
const found = accessList.results?.some((entry) => entry.ipAddress === ipInfo.currentIpv4Address);
expect(found).toBe(true);
});
});

Expand Down
12 changes: 12 additions & 0 deletions tests/integration/tools/atlas/dbUsers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ describeWithAtlas("db users", (integration) => {
expect(elements[0]?.text).toContain(userName);
expect(elements[0]?.text).toContain("with password: `");
});

it("should add current IP to access list when creating a database user", async () => {
const projectId = getProjectId();
const session = integration.mcpServer().session;
const ipInfo = await session.apiClient.getIpInfo();
await createUserWithMCP();
const accessList = await session.apiClient.listProjectIpAccessLists({
params: { path: { groupId: projectId } },
});
const found = accessList.results?.some((entry) => entry.ipAddress === ipInfo.currentIpv4Address);
expect(found).toBe(true);
});
});
describe("atlas-list-db-users", () => {
it("should have correct metadata", async () => {
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/accessListUtils.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { describe, it, expect, vi } from "vitest";
import { ApiClient } from "../../src/common/atlas/apiClient.js";
import { ensureCurrentIpInAccessList, DEFAULT_ACCESS_LIST_COMMENT } from "../../src/common/atlas/accessListUtils.js";
import { ApiClientError } from "../../src/common/atlas/apiClientError.js";

describe("accessListUtils", () => {
it("should add the current IP to the access list", async () => {
const apiClient = {
getIpInfo: vi.fn().mockResolvedValue({ currentIpv4Address: "127.0.0.1" } as never),
createProjectIpAccessList: vi.fn().mockResolvedValue(undefined as never),
} as unknown as ApiClient;
await ensureCurrentIpInAccessList(apiClient, "projectId");
// eslint-disable-next-line @typescript-eslint/unbound-method
expect(apiClient.createProjectIpAccessList).toHaveBeenCalledWith({
params: { path: { groupId: "projectId" } },
body: [{ groupId: "projectId", ipAddress: "127.0.0.1", comment: DEFAULT_ACCESS_LIST_COMMENT }],
});
});

it("should not fail if the current IP is already in the access list", async () => {
const apiClient = {
getIpInfo: vi.fn().mockResolvedValue({ currentIpv4Address: "127.0.0.1" } as never),
createProjectIpAccessList: vi
.fn()
.mockRejectedValue(
ApiClientError.fromError(
{ status: 409, statusText: "Conflict" } as Response,
{ message: "Conflict" } as never
) as never
),
} as unknown as ApiClient;
await ensureCurrentIpInAccessList(apiClient, "projectId");
// eslint-disable-next-line @typescript-eslint/unbound-method
expect(apiClient.createProjectIpAccessList).toHaveBeenCalledWith({
params: { path: { groupId: "projectId" } },
body: [{ groupId: "projectId", ipAddress: "127.0.0.1", comment: DEFAULT_ACCESS_LIST_COMMENT }],
});
});
});
Loading