diff --git a/src/common/atlas/accessListUtils.ts b/src/common/atlas/accessListUtils.ts new file mode 100644 index 00000000..dddb8605 --- /dev/null +++ b/src/common/atlas/accessListUtils.ts @@ -0,0 +1,54 @@ +import { ApiClient } from "./apiClient.js"; +import logger, { LogId } from "../logger.js"; +import { ApiClientError } from "./apiClientError.js"; + +export const DEFAULT_ACCESS_LIST_COMMENT = "Added by MongoDB MCP Server to enable tool access"; + +export async function makeCurrentIpAccessListEntry( + apiClient: ApiClient, + projectId: string, + comment: string = DEFAULT_ACCESS_LIST_COMMENT +) { + const { currentIpv4Address } = await apiClient.getIpInfo(); + return { + groupId: projectId, + ipAddress: currentIpv4Address, + comment, + }; +} + +/** + * Ensures the current public IP is in the access list for the given Atlas project. + * If the IP is already present, this is a no-op. + * @param apiClient The Atlas API client instance + * @param projectId The Atlas project ID + */ +export async function ensureCurrentIpInAccessList(apiClient: ApiClient, projectId: string): Promise { + const entry = await makeCurrentIpAccessListEntry(apiClient, projectId, DEFAULT_ACCESS_LIST_COMMENT); + try { + await apiClient.createProjectIpAccessList({ + params: { path: { groupId: projectId } }, + body: [entry], + }); + logger.debug( + LogId.atlasIpAccessListAdded, + "accessListUtils", + `IP access list created: ${JSON.stringify(entry)}` + ); + } catch (err) { + if (err instanceof ApiClientError && err.response?.status === 409) { + // 409 Conflict: entry already exists, log info + logger.debug( + LogId.atlasIpAccessListAdded, + "accessListUtils", + `IP address ${entry.ipAddress} is already present in the access list for project ${projectId}.` + ); + return; + } + logger.warning( + LogId.atlasIpAccessListAddFailure, + "accessListUtils", + `Error adding IP access list: ${err instanceof Error ? err.message : String(err)}` + ); + } +} diff --git a/src/common/logger.ts b/src/common/logger.ts index b1fb78a9..fc89f6bd 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -20,6 +20,8 @@ export const LogId = { atlasConnectAttempt: mongoLogId(1_001_005), atlasConnectSucceeded: mongoLogId(1_001_006), atlasApiRevokeFailure: mongoLogId(1_001_007), + atlasIpAccessListAdded: mongoLogId(1_001_008), + atlasIpAccessListAddFailure: mongoLogId(1_001_009), telemetryDisabled: mongoLogId(1_002_001), telemetryEmitFailure: mongoLogId(1_002_002), diff --git a/src/tools/atlas/connect/connectCluster.ts b/src/tools/atlas/connect/connectCluster.ts index d505dfed..e83c3040 100644 --- a/src/tools/atlas/connect/connectCluster.ts +++ b/src/tools/atlas/connect/connectCluster.ts @@ -5,6 +5,7 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { generateSecurePassword } from "../../../helpers/generatePassword.js"; import logger, { LogId } from "../../../common/logger.js"; import { inspectCluster } from "../../../common/atlas/cluster.js"; +import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js"; const EXPIRY_MS = 1000 * 60 * 60 * 12; // 12 hours @@ -198,6 +199,7 @@ export class ConnectClusterTool extends AtlasToolBase { } protected async execute({ projectId, clusterName }: ToolArgs): Promise { + await ensureCurrentIpInAccessList(this.session.apiClient, projectId); for (let i = 0; i < 60; i++) { const state = await this.queryConnection(projectId, clusterName); switch (state) { diff --git a/src/tools/atlas/create/createAccessList.ts b/src/tools/atlas/create/createAccessList.ts index 4941b1e8..3a2c1a22 100644 --- a/src/tools/atlas/create/createAccessList.ts +++ b/src/tools/atlas/create/createAccessList.ts @@ -2,8 +2,7 @@ import { z } from "zod"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { AtlasToolBase } from "../atlasTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; - -const DEFAULT_COMMENT = "Added by Atlas MCP"; +import { makeCurrentIpAccessListEntry, DEFAULT_ACCESS_LIST_COMMENT } from "../../../common/atlas/accessListUtils.js"; export class CreateAccessListTool extends AtlasToolBase { public name = "atlas-create-access-list"; @@ -17,7 +16,11 @@ export class CreateAccessListTool extends AtlasToolBase { .optional(), cidrBlocks: z.array(z.string().cidr()).describe("CIDR blocks to allow access from").optional(), currentIpAddress: z.boolean().describe("Add the current IP address").default(false), - comment: z.string().describe("Comment for the access list entries").default(DEFAULT_COMMENT).optional(), + comment: z + .string() + .describe("Comment for the access list entries") + .default(DEFAULT_ACCESS_LIST_COMMENT) + .optional(), }; protected async execute({ @@ -34,23 +37,22 @@ export class CreateAccessListTool extends AtlasToolBase { const ipInputs = (ipAddresses || []).map((ipAddress) => ({ groupId: projectId, ipAddress, - comment: comment || DEFAULT_COMMENT, + comment: comment || DEFAULT_ACCESS_LIST_COMMENT, })); if (currentIpAddress) { - const currentIp = await this.session.apiClient.getIpInfo(); - const input = { - groupId: projectId, - ipAddress: currentIp.currentIpv4Address, - comment: comment || DEFAULT_COMMENT, - }; + const input = await makeCurrentIpAccessListEntry( + this.session.apiClient, + projectId, + comment || DEFAULT_ACCESS_LIST_COMMENT + ); ipInputs.push(input); } const cidrInputs = (cidrBlocks || []).map((cidrBlock) => ({ groupId: projectId, cidrBlock, - comment: comment || DEFAULT_COMMENT, + comment: comment || DEFAULT_ACCESS_LIST_COMMENT, })); const inputs = [...ipInputs, ...cidrInputs]; diff --git a/src/tools/atlas/create/createDBUser.ts b/src/tools/atlas/create/createDBUser.ts index d2133a04..9541c281 100644 --- a/src/tools/atlas/create/createDBUser.ts +++ b/src/tools/atlas/create/createDBUser.ts @@ -4,6 +4,7 @@ import { AtlasToolBase } from "../atlasTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; import { CloudDatabaseUser, DatabaseUserRole } from "../../../common/atlas/openapi.js"; import { generateSecurePassword } from "../../../helpers/generatePassword.js"; +import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js"; export class CreateDBUserTool extends AtlasToolBase { public name = "atlas-create-db-user"; @@ -44,6 +45,7 @@ export class CreateDBUserTool extends AtlasToolBase { roles, clusters, }: ToolArgs): Promise { + await ensureCurrentIpInAccessList(this.session.apiClient, projectId); const shouldGeneratePassword = !password; if (shouldGeneratePassword) { password = await generateSecurePassword(); diff --git a/src/tools/atlas/create/createFreeCluster.ts b/src/tools/atlas/create/createFreeCluster.ts index ed04409b..0a8dda09 100644 --- a/src/tools/atlas/create/createFreeCluster.ts +++ b/src/tools/atlas/create/createFreeCluster.ts @@ -3,6 +3,7 @@ import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { AtlasToolBase } from "../atlasTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; import { ClusterDescription20240805 } from "../../../common/atlas/openapi.js"; +import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js"; export class CreateFreeClusterTool extends AtlasToolBase { public name = "atlas-create-free-cluster"; @@ -37,6 +38,7 @@ export class CreateFreeClusterTool extends AtlasToolBase { terminationProtectionEnabled: false, } as unknown as ClusterDescription20240805; + await ensureCurrentIpInAccessList(this.session.apiClient, projectId); await this.session.apiClient.createCluster({ params: { path: { diff --git a/tests/integration/tools/atlas/accessLists.test.ts b/tests/integration/tools/atlas/accessLists.test.ts index d5ab2916..8274de18 100644 --- a/tests/integration/tools/atlas/accessLists.test.ts +++ b/tests/integration/tools/atlas/accessLists.test.ts @@ -2,6 +2,7 @@ import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { describeWithAtlas, withProject } from "./atlasHelpers.js"; import { expectDefined } from "../../helpers.js"; import { afterAll, beforeAll, describe, expect, it } from "vitest"; +import { ensureCurrentIpInAccessList } from "../../../../src/common/atlas/accessListUtils.js"; function generateRandomIp() { const randomIp: number[] = [192]; @@ -95,5 +96,23 @@ describeWithAtlas("ip access lists", (integration) => { } }); }); + + describe("ensureCurrentIpInAccessList helper", () => { + it("should add the current IP to the access list and be idempotent", async () => { + const apiClient = integration.mcpServer().session.apiClient; + const projectId = getProjectId(); + const ipInfo = await apiClient.getIpInfo(); + // First call should add the IP + await expect(ensureCurrentIpInAccessList(apiClient, projectId)).resolves.not.toThrow(); + // Second call should be a no-op (idempotent) + await expect(ensureCurrentIpInAccessList(apiClient, projectId)).resolves.not.toThrow(); + // Check that the IP is present in the access list + const accessList = await apiClient.listProjectIpAccessLists({ + params: { path: { groupId: projectId } }, + }); + const found = accessList.results?.some((entry) => entry.ipAddress === ipInfo.currentIpv4Address); + expect(found).toBe(true); + }); + }); }); }); diff --git a/tests/integration/tools/atlas/clusters.test.ts b/tests/integration/tools/atlas/clusters.test.ts index ed6d0dd8..5cc3c1f6 100644 --- a/tests/integration/tools/atlas/clusters.test.ts +++ b/tests/integration/tools/atlas/clusters.test.ts @@ -82,8 +82,10 @@ describeWithAtlas("clusters", (integration) => { expect(createFreeCluster.inputSchema.properties).toHaveProperty("region"); }); - it("should create a free cluster", async () => { + it("should create a free cluster and add current IP to access list", async () => { const projectId = getProjectId(); + const session = integration.mcpServer().session; + const ipInfo = await session.apiClient.getIpInfo(); const response = (await integration.mcpClient().callTool({ name: "atlas-create-free-cluster", @@ -96,6 +98,13 @@ describeWithAtlas("clusters", (integration) => { expect(response.content).toBeInstanceOf(Array); expect(response.content).toHaveLength(2); expect(response.content[0]?.text).toContain("has been created"); + + // Check that the current IP is present in the access list + const accessList = await session.apiClient.listProjectIpAccessLists({ + params: { path: { groupId: projectId } }, + }); + const found = accessList.results?.some((entry) => entry.ipAddress === ipInfo.currentIpv4Address); + expect(found).toBe(true); }); }); diff --git a/tests/integration/tools/atlas/dbUsers.test.ts b/tests/integration/tools/atlas/dbUsers.test.ts index 05d0a098..387733a5 100644 --- a/tests/integration/tools/atlas/dbUsers.test.ts +++ b/tests/integration/tools/atlas/dbUsers.test.ts @@ -79,6 +79,18 @@ describeWithAtlas("db users", (integration) => { expect(elements[0]?.text).toContain(userName); expect(elements[0]?.text).toContain("with password: `"); }); + + it("should add current IP to access list when creating a database user", async () => { + const projectId = getProjectId(); + const session = integration.mcpServer().session; + const ipInfo = await session.apiClient.getIpInfo(); + await createUserWithMCP(); + const accessList = await session.apiClient.listProjectIpAccessLists({ + params: { path: { groupId: projectId } }, + }); + const found = accessList.results?.some((entry) => entry.ipAddress === ipInfo.currentIpv4Address); + expect(found).toBe(true); + }); }); describe("atlas-list-db-users", () => { it("should have correct metadata", async () => { diff --git a/tests/unit/accessListUtils.test.ts b/tests/unit/accessListUtils.test.ts new file mode 100644 index 00000000..6dc62b65 --- /dev/null +++ b/tests/unit/accessListUtils.test.ts @@ -0,0 +1,39 @@ +import { describe, it, expect, vi } from "vitest"; +import { ApiClient } from "../../src/common/atlas/apiClient.js"; +import { ensureCurrentIpInAccessList, DEFAULT_ACCESS_LIST_COMMENT } from "../../src/common/atlas/accessListUtils.js"; +import { ApiClientError } from "../../src/common/atlas/apiClientError.js"; + +describe("accessListUtils", () => { + it("should add the current IP to the access list", async () => { + const apiClient = { + getIpInfo: vi.fn().mockResolvedValue({ currentIpv4Address: "127.0.0.1" } as never), + createProjectIpAccessList: vi.fn().mockResolvedValue(undefined as never), + } as unknown as ApiClient; + await ensureCurrentIpInAccessList(apiClient, "projectId"); + // eslint-disable-next-line @typescript-eslint/unbound-method + expect(apiClient.createProjectIpAccessList).toHaveBeenCalledWith({ + params: { path: { groupId: "projectId" } }, + body: [{ groupId: "projectId", ipAddress: "127.0.0.1", comment: DEFAULT_ACCESS_LIST_COMMENT }], + }); + }); + + it("should not fail if the current IP is already in the access list", async () => { + const apiClient = { + getIpInfo: vi.fn().mockResolvedValue({ currentIpv4Address: "127.0.0.1" } as never), + createProjectIpAccessList: vi + .fn() + .mockRejectedValue( + ApiClientError.fromError( + { status: 409, statusText: "Conflict" } as Response, + { message: "Conflict" } as never + ) as never + ), + } as unknown as ApiClient; + await ensureCurrentIpInAccessList(apiClient, "projectId"); + // eslint-disable-next-line @typescript-eslint/unbound-method + expect(apiClient.createProjectIpAccessList).toHaveBeenCalledWith({ + params: { path: { groupId: "projectId" } }, + body: [{ groupId: "projectId", ipAddress: "127.0.0.1", comment: DEFAULT_ACCESS_LIST_COMMENT }], + }); + }); +});