From 2febc5de923f6be7a360d7921c9dcae30b941462 Mon Sep 17 00:00:00 2001 From: gagik Date: Thu, 10 Apr 2025 19:41:31 +0200 Subject: [PATCH 1/9] refactor: remove type assertions and simplify state setup --- eslint.config.js | 6 ++++ src/common/atlas/apiClient.ts | 13 +++++++ src/index.ts | 38 +++++++++++++++----- src/server.ts | 61 --------------------------------- src/state.ts | 3 -- src/tools/atlas/createDBUser.ts | 2 +- src/tools/atlas/listClusters.ts | 8 ++--- src/tools/atlas/listDBUsers.ts | 2 +- src/tools/atlas/listProjects.ts | 6 ++-- tests/unit/index.test.ts | 6 ++-- tsconfig.json | 1 + 11 files changed, 60 insertions(+), 86 deletions(-) delete mode 100644 src/server.ts diff --git a/eslint.config.js b/eslint.config.js index e93a22bf..3784279d 100644 --- a/eslint.config.js +++ b/eslint.config.js @@ -9,5 +9,11 @@ export default defineConfig([ { files: ["src/**/*.ts"], languageOptions: { globals: globals.node } }, tseslint.configs.recommended, eslintConfigPrettier, + { + files: ["src/**/*.ts"], + rules: { + "@typescript-eslint/no-non-null-assertion": "error", + }, + }, globalIgnores(["node_modules", "dist"]), ]); diff --git a/src/common/atlas/apiClient.ts b/src/common/atlas/apiClient.ts index 0c6615d7..66cf7f4f 100644 --- a/src/common/atlas/apiClient.ts +++ b/src/common/atlas/apiClient.ts @@ -2,6 +2,7 @@ import config from "../../config.js"; import createClient, { FetchOptions, Middleware } from "openapi-fetch"; import { paths, operations } from "./openapi.js"; +import { State } from "../../state.js"; export interface OAuthToken { access_token: string; @@ -85,6 +86,18 @@ export class ApiClient { this.client.use(this.errorMiddleware()); } + static fromState(state: State): ApiClient { + return new ApiClient({ + token: state.credentials.auth.token, + saveToken: async (token) => { + state.credentials.auth.code = undefined; + state.credentials.auth.token = token; + state.credentials.auth.status = "issued"; + await state.persistCredentials(); + }, + }); + } + async storeToken(token: OAuthToken): Promise { this.token = token; diff --git a/src/index.ts b/src/index.ts index 39fb0bc8..a355b97d 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,17 +1,37 @@ import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import { Server } from "./server.js"; import logger from "./logger.js"; import { mongoLogId } from "mongodb-log-writer"; +import { ApiClient } from "./common/atlas/apiClient.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import config from "./config.js"; +import { State } from "./state.js"; +import { registerAtlasTools } from "./tools/atlas/tools.js"; +import { registerMongoDBTools } from "./tools/mongodb/index.js"; export async function runServer() { - const server = new Server(); + try { + const state = new State(); + await state.loadCredentials(); - const transport = new StdioServerTransport(); - await server.connect(transport); -} + const apiClient = ApiClient.fromState(state); + + const mcp = new McpServer({ + name: "MongoDB Atlas", + version: config.version, + }); + + mcp.server.registerCapabilities({ logging: {} }); -runServer().catch((error) => { - logger.emergency(mongoLogId(1_000_004), "server", `Fatal error running server: ${error}`); + const transport = new StdioServerTransport(); + await mcp.server.connect(transport); + + registerAtlasTools(mcp, state, apiClient); + registerMongoDBTools(mcp, state); + } catch (error) { + logger.emergency(mongoLogId(1_000_004), "server", `Fatal error running server: ${error}`); + + process.exit(1); + } +} - process.exit(1); -}); +runServer(); diff --git a/src/server.ts b/src/server.ts deleted file mode 100644 index 0415f038..00000000 --- a/src/server.ts +++ /dev/null @@ -1,61 +0,0 @@ -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { ApiClient } from "./common/atlas/apiClient.js"; -import defaultState, { State } from "./state.js"; -import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; -import { registerAtlasTools } from "./tools/atlas/tools.js"; -import { registerMongoDBTools } from "./tools/mongodb/index.js"; -import config from "./config.js"; -import logger, { initializeLogger } from "./logger.js"; -import { mongoLogId } from "mongodb-log-writer"; - -export class Server { - state: State = defaultState; - apiClient: ApiClient | undefined = undefined; - initialized: boolean = false; - - private async init() { - if (this.initialized) { - return; - } - - await this.state.loadCredentials(); - - this.apiClient = new ApiClient({ - token: this.state.credentials.auth.token, - saveToken: async (token) => { - if (!this.state) { - throw new Error("State is not initialized"); - } - this.state.credentials.auth.code = undefined; - this.state.credentials.auth.token = token; - this.state.credentials.auth.status = "issued"; - await this.state.persistCredentials(); - }, - }); - - this.initialized = true; - } - - private createMcpServer(): McpServer { - const server = new McpServer({ - name: "MongoDB Atlas", - version: config.version, - }); - - server.server.registerCapabilities({ logging: {} }); - - registerAtlasTools(server, this.state, this.apiClient!); - registerMongoDBTools(server, this.state); - - return server; - } - - async connect(transport: Transport) { - await this.init(); - const server = this.createMcpServer(); - await server.connect(transport); - await initializeLogger(server); - - logger.info(mongoLogId(1_000_004), "server", `Server started with transport ${transport.constructor.name}`); - } -} diff --git a/src/state.ts b/src/state.ts index 9cc79626..6d293a27 100644 --- a/src/state.ts +++ b/src/state.ts @@ -40,6 +40,3 @@ export class State { } } } - -const defaultState = new State(); -export default defaultState; diff --git a/src/tools/atlas/createDBUser.ts b/src/tools/atlas/createDBUser.ts index d2a3b5d6..47459b0e 100644 --- a/src/tools/atlas/createDBUser.ts +++ b/src/tools/atlas/createDBUser.ts @@ -53,7 +53,7 @@ export class CreateDBUserTool extends AtlasToolBase { : undefined, } as CloudDatabaseUser; - await this.apiClient!.createDatabaseUser({ + await this.apiClient.createDatabaseUser({ params: { path: { groupId: projectId, diff --git a/src/tools/atlas/listClusters.ts b/src/tools/atlas/listClusters.ts index eda4d420..caeecb6f 100644 --- a/src/tools/atlas/listClusters.ts +++ b/src/tools/atlas/listClusters.ts @@ -47,8 +47,8 @@ export class ListClustersTool extends AtlasToolBase { if (!clusters?.results?.length) { throw new Error("No clusters found."); } - const rows = clusters - .results!.map((result) => { + const rows = clusters.results + .map((result) => { return (result.clusters || []).map((cluster) => { return { ...result, ...cluster, clusters: undefined }; }); @@ -75,8 +75,8 @@ ${rows}`, if (!clusters?.results?.length) { throw new Error("No clusters found."); } - const rows = clusters - .results!.map((cluster) => { + const rows = clusters.results + .map((cluster) => { const connectionString = cluster.connectionStrings?.standard || "N/A"; const mongoDBVersion = cluster.mongoDBVersion || "N/A"; return `${cluster.name} | ${cluster.stateName} | ${mongoDBVersion} | ${connectionString}`; diff --git a/src/tools/atlas/listDBUsers.ts b/src/tools/atlas/listDBUsers.ts index d49d981b..d4f1e19b 100644 --- a/src/tools/atlas/listDBUsers.ts +++ b/src/tools/atlas/listDBUsers.ts @@ -14,7 +14,7 @@ export class ListDBUsersTool extends AtlasToolBase { protected async execute({ projectId }: ToolArgs): Promise { await this.ensureAuthenticated(); - const data = await this.apiClient!.listDatabaseUsers({ + const data = await this.apiClient.listDatabaseUsers({ params: { path: { groupId: projectId, diff --git a/src/tools/atlas/listProjects.ts b/src/tools/atlas/listProjects.ts index 6b4b7d4a..438a3cd8 100644 --- a/src/tools/atlas/listProjects.ts +++ b/src/tools/atlas/listProjects.ts @@ -9,15 +9,15 @@ export class ListProjectsTool extends AtlasToolBase { protected async execute(): Promise { await this.ensureAuthenticated(); - const data = await this.apiClient!.listProjects(); + const data = await this.apiClient.listProjects(); if (!data?.results?.length) { throw new Error("No projects found in your MongoDB Atlas account."); } // Format projects as a table - const rows = data! - .results!.map((project) => { + const rows = data.results + .map((project) => { const createdAt = project.created ? new Date(project.created).toLocaleString() : "N/A"; return `${project.name} | ${project.id} | ${createdAt}`; }) diff --git a/tests/unit/index.test.ts b/tests/unit/index.test.ts index 8773fd75..f87229b9 100644 --- a/tests/unit/index.test.ts +++ b/tests/unit/index.test.ts @@ -1,6 +1,5 @@ import { describe, it } from "@jest/globals"; -import { runServer } from "../../src/index"; -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio"; +import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; // mock the StdioServerTransport jest.mock("@modelcontextprotocol/sdk/server/stdio"); @@ -21,7 +20,6 @@ jest.mock("../../src/server.ts", () => { describe("Server initialization", () => { it("should create a server instance", async () => { - await runServer(); - expect(StdioServerTransport).toHaveBeenCalled(); + await expect(StdioServerTransport).toHaveBeenCalled(); }); }); diff --git a/tsconfig.json b/tsconfig.json index a195f859..1fe57f10 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -6,6 +6,7 @@ "rootDir": "./src", "outDir": "./dist", "strict": true, + "strictNullChecks": true, "esModuleInterop": true, "types": ["node"], "sourceMap": true, From 6c8752bc806bc2152b1f98da2bef4c2c542aa9dd Mon Sep 17 00:00:00 2001 From: gagik Date: Thu, 10 Apr 2025 21:10:22 +0200 Subject: [PATCH 2/9] fix: tests and token refactor --- src/common/atlas/apiClient.ts | 48 +++++++++++++++++++---------------- src/index.ts | 36 ++++++++++++-------------- tests/unit/index.test.ts | 29 +++++++-------------- 3 files changed, 51 insertions(+), 62 deletions(-) diff --git a/src/common/atlas/apiClient.ts b/src/common/atlas/apiClient.ts index 66cf7f4f..e930d7bb 100644 --- a/src/common/atlas/apiClient.ts +++ b/src/common/atlas/apiClient.ts @@ -51,39 +51,40 @@ export interface ApiClientOptions { export class ApiClient { private token?: OAuthToken; - private saveToken?: saveTokenFunction; - private client = createClient({ + private readonly saveToken?: saveTokenFunction; + private readonly client = createClient({ baseUrl: config.apiBaseUrl, headers: { "User-Agent": config.userAgent, Accept: `application/vnd.atlas.${config.atlasApiVersion}+json`, }, }); - private authMiddleware = (apiClient: ApiClient): Middleware => ({ - async onRequest({ request, schemaPath }) { + + private readonly authMiddleware: Middleware = { + onRequest: async ({ request, schemaPath }) => { if (schemaPath.startsWith("/api/private/unauth") || schemaPath.startsWith("/api/oauth")) { return undefined; } - if (await apiClient.validateToken()) { - request.headers.set("Authorization", `Bearer ${apiClient.token!.access_token}`); + if (this.token && (await this.validateToken())) { + request.headers.set("Authorization", `Bearer ${this.token.access_token}`); return request; } }, - }); - private errorMiddleware = (): Middleware => ({ + }; + private readonly errorMiddleware: Middleware = { async onResponse({ response }) { if (!response.ok) { throw await ApiClientError.fromResponse(response); } }, - }); + }; constructor(options: ApiClientOptions) { const { token, saveToken } = options; this.token = token; this.saveToken = saveToken; - this.client.use(this.authMiddleware(this)); - this.client.use(this.errorMiddleware()); + this.client.use(this.authMiddleware); + this.client.use(this.errorMiddleware); } static fromState(state: State): ApiClient { @@ -173,7 +174,7 @@ export class ApiClient { } } - async refreshToken(token?: OAuthToken): Promise { + async refreshToken(token: OAuthToken): Promise { const endpoint = "api/private/unauth/account/device/token"; const url = new URL(endpoint, config.apiBaseUrl); const response = await fetch(url, { @@ -184,7 +185,7 @@ export class ApiClient { }, body: new URLSearchParams({ client_id: config.clientId, - refresh_token: (token || this.token)?.refresh_token || "", + refresh_token: token.refresh_token, grant_type: "refresh_token", scope: "openid profile offline_access", }).toString(), @@ -207,7 +208,7 @@ export class ApiClient { return await this.storeToken(tokenToStore); } - async revokeToken(token?: OAuthToken): Promise { + async revokeToken(token: OAuthToken): Promise { const endpoint = "api/private/unauth/account/device/token"; const url = new URL(endpoint, config.apiBaseUrl); const response = await fetch(url, { @@ -219,7 +220,7 @@ export class ApiClient { }, body: new URLSearchParams({ client_id: config.clientId, - token: (token || this.token)?.access_token || "", + token: token.access_token || "", token_type_hint: "refresh_token", }).toString(), }); @@ -235,9 +236,8 @@ export class ApiClient { return; } - private checkTokenExpiry(token?: OAuthToken): boolean { + private checkTokenExpiry(token: OAuthToken): boolean { try { - token = token || this.token; if (!token || !token.access_token) { return false; } @@ -252,13 +252,17 @@ export class ApiClient { } } - async validateToken(token?: OAuthToken): Promise { - if (this.checkTokenExpiry(token)) { + async validateToken(): Promise { + if (!this.token) { + return false; + } + + if (this.checkTokenExpiry(this.token)) { return true; } try { - await this.refreshToken(token); + await this.refreshToken(this.token); return true; } catch { return false; @@ -266,7 +270,7 @@ export class ApiClient { } async getIpInfo() { - if (!(await this.validateToken())) { + if (!this.token || !(await this.validateToken())) { throw new Error("Not Authenticated"); } @@ -276,7 +280,7 @@ export class ApiClient { method: "GET", headers: { Accept: "application/json", - Authorization: `Bearer ${this.token!.access_token}`, + Authorization: `Bearer ${this.token.access_token}`, "User-Agent": config.userAgent, }, }); diff --git a/src/index.ts b/src/index.ts index a355b97d..dfff1475 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,30 +8,26 @@ import { State } from "./state.js"; import { registerAtlasTools } from "./tools/atlas/tools.js"; import { registerMongoDBTools } from "./tools/mongodb/index.js"; -export async function runServer() { - try { - const state = new State(); - await state.loadCredentials(); +try { + const state = new State(); + await state.loadCredentials(); - const apiClient = ApiClient.fromState(state); + const apiClient = ApiClient.fromState(state); - const mcp = new McpServer({ - name: "MongoDB Atlas", - version: config.version, - }); + const mcp = new McpServer({ + name: "MongoDB Atlas", + version: config.version, + }); - mcp.server.registerCapabilities({ logging: {} }); + mcp.server.registerCapabilities({ logging: {} }); - const transport = new StdioServerTransport(); - await mcp.server.connect(transport); + registerAtlasTools(mcp, state, apiClient); + registerMongoDBTools(mcp, state); - registerAtlasTools(mcp, state, apiClient); - registerMongoDBTools(mcp, state); - } catch (error) { - logger.emergency(mongoLogId(1_000_004), "server", `Fatal error running server: ${error}`); + const transport = new StdioServerTransport(); + await mcp.server.connect(transport); +} catch (error) { + logger.emergency(mongoLogId(1_000_004), "server", `Fatal error running server: ${error}`); - process.exit(1); - } + process.exit(1); } - -runServer(); diff --git a/tests/unit/index.test.ts b/tests/unit/index.test.ts index f87229b9..a1b413f2 100644 --- a/tests/unit/index.test.ts +++ b/tests/unit/index.test.ts @@ -1,25 +1,14 @@ import { describe, it } from "@jest/globals"; -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; - -// mock the StdioServerTransport -jest.mock("@modelcontextprotocol/sdk/server/stdio"); -// mock Server class and its methods -jest.mock("../../src/server.ts", () => { - return { - Server: jest.fn().mockImplementation(() => { - return { - connect: jest.fn().mockImplementation((transport) => { - return new Promise((resolve) => { - resolve(transport); - }); - }), - }; - }), - }; -}); +import { State } from "../../src/state"; describe("Server initialization", () => { - it("should create a server instance", async () => { - await expect(StdioServerTransport).toHaveBeenCalled(); + it("should define a default state", async () => { + const state = new State(); + + expect(state.credentials).toEqual({ + auth: { + status: "not_auth", + }, + }); }); }); From af90a87ddb8ffa2199bfb8bfbb5f8625593ac520 Mon Sep 17 00:00:00 2001 From: gagik Date: Thu, 10 Apr 2025 16:06:43 +0200 Subject: [PATCH 3/9] chore: support common.js c --- .../mongodb/metadata/collectionExplain.ts | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 src/tools/mongodb/metadata/collectionExplain.ts diff --git a/src/tools/mongodb/metadata/collectionExplain.ts b/src/tools/mongodb/metadata/collectionExplain.ts new file mode 100644 index 00000000..6d835b33 --- /dev/null +++ b/src/tools/mongodb/metadata/collectionExplain.ts @@ -0,0 +1,67 @@ +import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; +import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; +import { ToolArgs } from "../../tool.js"; +import { parseSchema, SchemaField } from "mongodb-schema"; +import { z } from "zod"; +import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; +import { ExplainVerbosity } from "mongodb"; + +export class CollectionExplainTool extends MongoDBToolBase { + protected name = "collection-explain"; + protected description = "Returns statistics describing the execution of the winning plan for the evaluated method"; + + static supportedOperations = [ + "aggregate", + "count", + "distinct", + "find", + "findAndModify", + "delete", + "mapReduce", + "update", + ] as const; + + protected argsShape = { + ...DbOperationArgs, + operation: z.enum(CollectionExplainTool.supportedOperations).describe("Method to explain"), + operationsArguments: z.any().describe("Arguments used by the method to be explained"), + }; + + protected operationType: DbOperationType = "metadata"; + + protected async execute({ + database, + collection, + operation, + operationsArguments, + }: ToolArgs): Promise { + const provider = this.ensureConnected(); + + const documents = await provider.runCommand(database, { + explain: operation, + verbosity: ExplainVerbosity.queryPlanner, + }); + + return { + content: [ + { + text: `Found ${schema.fields.length} fields in the schema for \`${database}.${collection}\``, + type: "text", + }, + { + text: this.formatFieldOutput(schema.fields), + type: "text", + }, + ], + }; + } + + private formatFieldOutput(fields: SchemaField[]): string { + let result = "| Field | Type | Confidence |\n"; + result += "|-------|------|-------------|\n"; + for (const field of fields) { + result += `| ${field.name} | \`${field.type}\` | ${(field.probability * 100).toFixed(0)}% |\n`; + } + return result; + } +} From a0bbbe96182829f821656b1bbc6c79efdbf4c4cd Mon Sep 17 00:00:00 2001 From: gagik Date: Thu, 10 Apr 2025 21:34:54 +0200 Subject: [PATCH 4/9] feat: add explain tool --- src/tools/mongodb/index.ts | 2 + .../mongodb/metadata/collectionExplain.ts | 67 ------------------ src/tools/mongodb/metadata/explain.ts | 69 +++++++++++++++++++ 3 files changed, 71 insertions(+), 67 deletions(-) delete mode 100644 src/tools/mongodb/metadata/collectionExplain.ts create mode 100644 src/tools/mongodb/metadata/explain.ts diff --git a/src/tools/mongodb/index.ts b/src/tools/mongodb/index.ts index be30c494..33f04314 100644 --- a/src/tools/mongodb/index.ts +++ b/src/tools/mongodb/index.ts @@ -20,6 +20,7 @@ import { UpdateManyTool } from "./update/updateMany.js"; import { RenameCollectionTool } from "./update/renameCollection.js"; import { DropDatabaseTool } from "./delete/dropDatabase.js"; import { DropCollectionTool } from "./delete/dropCollection.js"; +import { ExplainTool } from "./metadata/explain.js"; export function registerMongoDBTools(server: McpServer, state: State) { const tools = [ @@ -43,6 +44,7 @@ export function registerMongoDBTools(server: McpServer, state: State) { RenameCollectionTool, DropDatabaseTool, DropCollectionTool, + ExplainTool, ]; for (const tool of tools) { diff --git a/src/tools/mongodb/metadata/collectionExplain.ts b/src/tools/mongodb/metadata/collectionExplain.ts deleted file mode 100644 index 6d835b33..00000000 --- a/src/tools/mongodb/metadata/collectionExplain.ts +++ /dev/null @@ -1,67 +0,0 @@ -import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; -import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; -import { ToolArgs } from "../../tool.js"; -import { parseSchema, SchemaField } from "mongodb-schema"; -import { z } from "zod"; -import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; -import { ExplainVerbosity } from "mongodb"; - -export class CollectionExplainTool extends MongoDBToolBase { - protected name = "collection-explain"; - protected description = "Returns statistics describing the execution of the winning plan for the evaluated method"; - - static supportedOperations = [ - "aggregate", - "count", - "distinct", - "find", - "findAndModify", - "delete", - "mapReduce", - "update", - ] as const; - - protected argsShape = { - ...DbOperationArgs, - operation: z.enum(CollectionExplainTool.supportedOperations).describe("Method to explain"), - operationsArguments: z.any().describe("Arguments used by the method to be explained"), - }; - - protected operationType: DbOperationType = "metadata"; - - protected async execute({ - database, - collection, - operation, - operationsArguments, - }: ToolArgs): Promise { - const provider = this.ensureConnected(); - - const documents = await provider.runCommand(database, { - explain: operation, - verbosity: ExplainVerbosity.queryPlanner, - }); - - return { - content: [ - { - text: `Found ${schema.fields.length} fields in the schema for \`${database}.${collection}\``, - type: "text", - }, - { - text: this.formatFieldOutput(schema.fields), - type: "text", - }, - ], - }; - } - - private formatFieldOutput(fields: SchemaField[]): string { - let result = "| Field | Type | Confidence |\n"; - result += "|-------|------|-------------|\n"; - for (const field of fields) { - result += `| ${field.name} | \`${field.type}\` | ${(field.probability * 100).toFixed(0)}% |\n`; - } - return result; - } -} diff --git a/src/tools/mongodb/metadata/explain.ts b/src/tools/mongodb/metadata/explain.ts new file mode 100644 index 00000000..2d4e173c --- /dev/null +++ b/src/tools/mongodb/metadata/explain.ts @@ -0,0 +1,69 @@ +import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; +import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; +import { ToolArgs } from "../../tool.js"; +import { z } from "zod"; +import { ExplainVerbosity, Document } from "mongodb"; + +export class ExplainTool extends MongoDBToolBase { + protected name = "explain"; + protected description = + "Returns statistics describing the execution of the winning plan chosen by the query optimizer for the evaluated method"; + + protected argsShape = { + ...DbOperationArgs, + method: z.enum(["aggregate", "find"]).describe("The method to run"), + methodArguments: z + .object({ + aggregatePipeline: z + .array(z.object({}).passthrough()) + .optional() + .describe("aggregate - array of aggregation stages to execute"), + + findQuery: z.object({}).passthrough().optional().describe("find - The query to run"), + findProjection: z.object({}).passthrough().optional().describe("find - The projection to apply"), + }) + .describe("The arguments for the method"), + }; + + protected operationType: DbOperationType = "metadata"; + + protected async execute({ + database, + collection, + method, + methodArguments, + }: ToolArgs): Promise { + const provider = this.ensureConnected(); + + let result: Document; + switch (method) { + case "aggregate": { + result = await provider.aggregate(database, collection).explain(); + break; + } + case "find": { + const query = methodArguments.findQuery ?? {}; + const projection = methodArguments.findProjection ?? {}; + result = await provider + .find(database, collection, query, { projection }) + .explain(ExplainVerbosity.queryPlanner); + break; + } + default: + throw new Error(`Unsupported method: ${method}`); + } + + return { + content: [ + { + text: `Here is some information about the winning plan chosen by the query optimizer for running the given \`${method}\` operation in \`${database}\`. This information can be used to understand how the query was executed and to optimize the query performance.`, + type: "text", + }, + { + text: JSON.stringify(result), + type: "text", + }, + ], + }; + } +} From 599a179d1379308047f83fd6a7f8bba0103e5368 Mon Sep 17 00:00:00 2001 From: gagik Date: Fri, 11 Apr 2025 10:33:03 +0200 Subject: [PATCH 5/9] fix: use union types and reuse argument definitions --- src/tools/mongodb/metadata/explain.ts | 66 +++++++++++++++++++-------- src/tools/mongodb/read/aggregate.ts | 16 ++++--- src/tools/mongodb/read/count.ts | 18 +++++--- src/tools/mongodb/read/find.ts | 41 +++++++++-------- 4 files changed, 88 insertions(+), 53 deletions(-) diff --git a/src/tools/mongodb/metadata/explain.ts b/src/tools/mongodb/metadata/explain.ts index 2d4e173c..7575131c 100644 --- a/src/tools/mongodb/metadata/explain.ts +++ b/src/tools/mongodb/metadata/explain.ts @@ -3,6 +3,9 @@ import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbToo import { ToolArgs } from "../../tool.js"; import { z } from "zod"; import { ExplainVerbosity, Document } from "mongodb"; +import { AggregateArgs } from "../read/aggregate.js"; +import { FindArgs } from "../read/find.js"; +import { CountArgs } from "../read/count.js"; export class ExplainTool extends MongoDBToolBase { protected name = "explain"; @@ -11,42 +14,65 @@ export class ExplainTool extends MongoDBToolBase { protected argsShape = { ...DbOperationArgs, - method: z.enum(["aggregate", "find"]).describe("The method to run"), - methodArguments: z - .object({ - aggregatePipeline: z - .array(z.object({}).passthrough()) - .optional() - .describe("aggregate - array of aggregation stages to execute"), - - findQuery: z.object({}).passthrough().optional().describe("find - The query to run"), - findProjection: z.object({}).passthrough().optional().describe("find - The projection to apply"), - }) - .describe("The arguments for the method"), + method: z + .array( + z.union([ + z.object({ + name: z.literal("aggregate"), + arguments: z.object(AggregateArgs), + }), + z.object({ + name: z.literal("find"), + arguments: z.object(FindArgs), + }), + z.object({ + name: z.literal("count"), + arguments: z.object(CountArgs), + }), + ]) + ) + .describe("The method and its arguments to run"), }; protected operationType: DbOperationType = "metadata"; + static readonly defaultVerbosity = ExplainVerbosity.queryPlanner; + protected async execute({ database, collection, - method, - methodArguments, + method: methods, }: ToolArgs): Promise { const provider = this.ensureConnected(); + const method = methods[0]; + + if (!method) { + throw new Error("No method provided"); + } let result: Document; - switch (method) { + switch (method.name) { case "aggregate": { - result = await provider.aggregate(database, collection).explain(); + const { pipeline, limit } = method.arguments; + result = await provider + .aggregate(database, collection, pipeline) + .limit(limit) + .explain(ExplainTool.defaultVerbosity); break; } case "find": { - const query = methodArguments.findQuery ?? {}; - const projection = methodArguments.findProjection ?? {}; + const { filter, ...rest } = method.arguments; result = await provider - .find(database, collection, query, { projection }) - .explain(ExplainVerbosity.queryPlanner); + .find(database, collection, filter, { ...rest }) + .explain(ExplainTool.defaultVerbosity); + break; + } + case "count": { + const { query } = method.arguments; + // This helper doesn't have explain() command but does have the argument explain + result = (await provider.count(database, collection, query, { + explain: ExplainTool.defaultVerbosity, + })) as unknown as Document; break; } default: diff --git a/src/tools/mongodb/read/aggregate.ts b/src/tools/mongodb/read/aggregate.ts index f6acb18f..0194d1f3 100644 --- a/src/tools/mongodb/read/aggregate.ts +++ b/src/tools/mongodb/read/aggregate.ts @@ -1,16 +1,19 @@ import { z } from "zod"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; -import { DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; +import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs } from "../../tool.js"; +export const AggregateArgs = { + pipeline: z.array(z.object({}).passthrough()).describe("An array of aggregation stages to execute"), + limit: z.number().optional().default(10).describe("The maximum number of documents to return"), +}; + export class AggregateTool extends MongoDBToolBase { protected name = "aggregate"; protected description = "Run an aggregation against a MongoDB collection"; protected argsShape = { - collection: z.string().describe("Collection name"), - database: z.string().describe("Database name"), - pipeline: z.array(z.object({}).passthrough()).describe("An array of aggregation stages to execute"), - limit: z.number().optional().default(10).describe("The maximum number of documents to return"), + ...DbOperationArgs, + ...AggregateArgs, }; protected operationType: DbOperationType = "read"; @@ -18,9 +21,10 @@ export class AggregateTool extends MongoDBToolBase { database, collection, pipeline, + limit, }: ToolArgs): Promise { const provider = this.ensureConnected(); - const documents = await provider.aggregate(database, collection, pipeline).toArray(); + const documents = await provider.aggregate(database, collection, pipeline).limit(limit).toArray(); const content: Array<{ text: string; type: "text" }> = [ { diff --git a/src/tools/mongodb/read/count.ts b/src/tools/mongodb/read/count.ts index 956e160c..390c35a2 100644 --- a/src/tools/mongodb/read/count.ts +++ b/src/tools/mongodb/read/count.ts @@ -3,18 +3,22 @@ import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbToo import { ToolArgs } from "../../tool.js"; import { z } from "zod"; +export const CountArgs = { + query: z + .object({}) + .passthrough() + .optional() + .describe( + "The query filter to count documents. Matches the syntax of the filter argument of db.collection.count()" + ), +}; + export class CountTool extends MongoDBToolBase { protected name = "count"; protected description = "Gets the number of documents in a MongoDB collection"; protected argsShape = { ...DbOperationArgs, - query: z - .object({}) - .passthrough() - .optional() - .describe( - "The query filter to count documents. Matches the syntax of the filter argument of db.collection.count()" - ), + ...CountArgs, }; protected operationType: DbOperationType = "metadata"; diff --git a/src/tools/mongodb/read/find.ts b/src/tools/mongodb/read/find.ts index 04392139..4eaf11be 100644 --- a/src/tools/mongodb/read/find.ts +++ b/src/tools/mongodb/read/find.ts @@ -1,32 +1,33 @@ import { z } from "zod"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; -import { DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; +import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs } from "../../tool.js"; import { SortDirection } from "mongodb"; +export const FindArgs = { + filter: z + .object({}) + .passthrough() + .optional() + .describe("The query filter, matching the syntax of the query argument of db.collection.find()"), + projection: z + .object({}) + .passthrough() + .optional() + .describe("The projection, matching the syntax of the projection argument of db.collection.find()"), + limit: z.number().optional().default(10).describe("The maximum number of documents to return"), + sort: z + .record(z.string(), z.custom()) + .optional() + .describe("A document, describing the sort order, matching the syntax of the sort argument of cursor.sort()"), +}; + export class FindTool extends MongoDBToolBase { protected name = "find"; protected description = "Run a find query against a MongoDB collection"; protected argsShape = { - collection: z.string().describe("Collection name"), - database: z.string().describe("Database name"), - filter: z - .object({}) - .passthrough() - .optional() - .describe("The query filter, matching the syntax of the query argument of db.collection.find()"), - projection: z - .object({}) - .passthrough() - .optional() - .describe("The projection, matching the syntax of the projection argument of db.collection.find()"), - limit: z.number().optional().default(10).describe("The maximum number of documents to return"), - sort: z - .record(z.string(), z.custom()) - .optional() - .describe( - "A document, describing the sort order, matching the syntax of the sort argument of cursor.sort()" - ), + ...DbOperationArgs, + ...FindArgs, }; protected operationType: DbOperationType = "read"; From e51e7ed938356c5adfb1d6f4b72a58dda701b69c Mon Sep 17 00:00:00 2001 From: Gagik Amaryan Date: Fri, 11 Apr 2025 14:43:02 +0200 Subject: [PATCH 6/9] Update src/tools/mongodb/metadata/explain.ts Co-authored-by: Nikola Irinchev --- src/tools/mongodb/metadata/explain.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tools/mongodb/metadata/explain.ts b/src/tools/mongodb/metadata/explain.ts index 7575131c..adadf574 100644 --- a/src/tools/mongodb/metadata/explain.ts +++ b/src/tools/mongodb/metadata/explain.ts @@ -82,7 +82,7 @@ export class ExplainTool extends MongoDBToolBase { return { content: [ { - text: `Here is some information about the winning plan chosen by the query optimizer for running the given \`${method}\` operation in \`${database}\`. This information can be used to understand how the query was executed and to optimize the query performance.`, + text: `Here is some information about the winning plan chosen by the query optimizer for running the given \`${method}\` operation in \`${database}.${collection}\`. This information can be used to understand how the query was executed and to optimize the query performance.`, type: "text", }, { From d78d7f62e7ffb06c53f01b596732c5863a4539c8 Mon Sep 17 00:00:00 2001 From: gagik Date: Fri, 11 Apr 2025 15:29:46 +0200 Subject: [PATCH 7/9] fix: remove limit --- src/tools/mongodb/metadata/explain.ts | 7 ++----- src/tools/mongodb/read/aggregate.ts | 3 +-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/tools/mongodb/metadata/explain.ts b/src/tools/mongodb/metadata/explain.ts index adadf574..5cea2c96 100644 --- a/src/tools/mongodb/metadata/explain.ts +++ b/src/tools/mongodb/metadata/explain.ts @@ -53,11 +53,8 @@ export class ExplainTool extends MongoDBToolBase { let result: Document; switch (method.name) { case "aggregate": { - const { pipeline, limit } = method.arguments; - result = await provider - .aggregate(database, collection, pipeline) - .limit(limit) - .explain(ExplainTool.defaultVerbosity); + const { pipeline } = method.arguments; + result = await provider.aggregate(database, collection, pipeline).explain(ExplainTool.defaultVerbosity); break; } case "find": { diff --git a/src/tools/mongodb/read/aggregate.ts b/src/tools/mongodb/read/aggregate.ts index 0194d1f3..a1aa83c4 100644 --- a/src/tools/mongodb/read/aggregate.ts +++ b/src/tools/mongodb/read/aggregate.ts @@ -21,10 +21,9 @@ export class AggregateTool extends MongoDBToolBase { database, collection, pipeline, - limit, }: ToolArgs): Promise { const provider = this.ensureConnected(); - const documents = await provider.aggregate(database, collection, pipeline).limit(limit).toArray(); + const documents = await provider.aggregate(database, collection, pipeline).toArray(); const content: Array<{ text: string; type: "text" }> = [ { From 7bf120e600fab8dcbc30563305263c28d396bee5 Mon Sep 17 00:00:00 2001 From: gagik Date: Fri, 11 Apr 2025 17:01:34 +0200 Subject: [PATCH 8/9] fix: align with main --- src/common/atlas/apiClient.ts | 75 ++++++++++++------------------ src/common/atlas/apiClientError.ts | 21 +++++++++ src/tools/mongodb/tools.ts | 1 + 3 files changed, 52 insertions(+), 45 deletions(-) create mode 100644 src/common/atlas/apiClientError.ts diff --git a/src/common/atlas/apiClient.ts b/src/common/atlas/apiClient.ts index d0e6e843..b784e43e 100644 --- a/src/common/atlas/apiClient.ts +++ b/src/common/atlas/apiClient.ts @@ -1,32 +1,11 @@ import config from "../../config.js"; import createClient, { Client, FetchOptions, Middleware } from "openapi-fetch"; import { AccessToken, ClientCredentials } from "simple-oauth2"; - +import { ApiClientError } from "./apiClientError.js"; import { paths, operations } from "./openapi.js"; -import { State } from "../../state.js"; const ATLAS_API_VERSION = "2025-03-12"; -export class ApiClientError extends Error { - response?: Response; - - constructor(message: string, response: Response | undefined = undefined) { - super(message); - this.name = "ApiClientError"; - this.response = response; - } - - static async fromResponse(response: Response, message?: string): Promise { - message ||= `error calling Atlas API`; - try { - const text = await response.text(); - return new ApiClientError(`${message}: [${response.status} ${response.statusText}] ${text}`, response); - } catch { - return new ApiClientError(`${message}: ${response.status} ${response.statusText}`, response); - } - } -} - export interface ApiClientOptions { credentials?: { clientId: string; @@ -71,6 +50,7 @@ export class ApiClient { } }, }; + private readonly errorMiddleware: Middleware = { async onResponse({ response }) { if (!response.ok) { @@ -79,15 +59,13 @@ export class ApiClient { }, }; - constructor(options: ApiClientOptions) { - const defaultOptions = { - baseUrl: "https://cloud.mongodb.com/", - userAgent: `AtlasMCP/${config.version} (${process.platform}; ${process.arch}; ${process.env.HOSTNAME || "unknown"})`, - }; - + constructor(options?: ApiClientOptions) { this.options = { - ...defaultOptions, ...options, + baseUrl: options?.baseUrl || "https://cloud.mongodb.com/", + userAgent: + options?.userAgent || + `AtlasMCP/${config.version} (${process.platform}; ${process.arch}; ${process.env.HOSTNAME || "unknown"})`, }; this.client = createClient({ @@ -138,38 +116,39 @@ export class ApiClient { }>; } - async listProjects(options?: FetchOptions) { - const { data } = await this.client.GET(`/api/atlas/v2/groups`, options); + // DO NOT EDIT. This is auto-generated code. + async listClustersForAllProjects(options?: FetchOptions) { + const { data } = await this.client.GET("/api/atlas/v2/clusters", options); return data; } - async listProjectIpAccessLists(options: FetchOptions) { - const { data } = await this.client.GET(`/api/atlas/v2/groups/{groupId}/accessList`, options); + async listProjects(options?: FetchOptions) { + const { data } = await this.client.GET("/api/atlas/v2/groups", options); return data; } - async createProjectIpAccessList(options: FetchOptions) { - const { data } = await this.client.POST(`/api/atlas/v2/groups/{groupId}/accessList`, options); + async createProject(options: FetchOptions) { + const { data } = await this.client.POST("/api/atlas/v2/groups", options); return data; } async getProject(options: FetchOptions) { - const { data } = await this.client.GET(`/api/atlas/v2/groups/{groupId}`, options); + const { data } = await this.client.GET("/api/atlas/v2/groups/{groupId}", options); return data; } - async listClusters(options: FetchOptions) { - const { data } = await this.client.GET(`/api/atlas/v2/groups/{groupId}/clusters`, options); + async listProjectIpAccessLists(options: FetchOptions) { + const { data } = await this.client.GET("/api/atlas/v2/groups/{groupId}/accessList", options); return data; } - async listClustersForAllProjects(options?: FetchOptions) { - const { data } = await this.client.GET(`/api/atlas/v2/clusters`, options); + async createProjectIpAccessList(options: FetchOptions) { + const { data } = await this.client.POST("/api/atlas/v2/groups/{groupId}/accessList", options); return data; } - async getCluster(options: FetchOptions) { - const { data } = await this.client.GET(`/api/atlas/v2/groups/{groupId}/clusters/{clusterName}`, options); + async listClusters(options: FetchOptions) { + const { data } = await this.client.GET("/api/atlas/v2/groups/{groupId}/clusters", options); return data; } @@ -178,13 +157,19 @@ export class ApiClient { return data; } - async createDatabaseUser(options: FetchOptions) { - const { data } = await this.client.POST("/api/atlas/v2/groups/{groupId}/databaseUsers", options); + async getCluster(options: FetchOptions) { + const { data } = await this.client.GET("/api/atlas/v2/groups/{groupId}/clusters/{clusterName}", options); return data; } async listDatabaseUsers(options: FetchOptions) { - const { data } = await this.client.GET(`/api/atlas/v2/groups/{groupId}/databaseUsers`, options); + const { data } = await this.client.GET("/api/atlas/v2/groups/{groupId}/databaseUsers", options); + return data; + } + + async createDatabaseUser(options: FetchOptions) { + const { data } = await this.client.POST("/api/atlas/v2/groups/{groupId}/databaseUsers", options); return data; } + // DO NOT EDIT. This is auto-generated code. } diff --git a/src/common/atlas/apiClientError.ts b/src/common/atlas/apiClientError.ts new file mode 100644 index 00000000..6073c161 --- /dev/null +++ b/src/common/atlas/apiClientError.ts @@ -0,0 +1,21 @@ +export class ApiClientError extends Error { + response?: Response; + + constructor(message: string, response: Response | undefined = undefined) { + super(message); + this.name = "ApiClientError"; + this.response = response; + } + + static async fromResponse( + response: Response, + message: string = `error calling Atlas API` + ): Promise { + try { + const text = await response.text(); + return new ApiClientError(`${message}: [${response.status} ${response.statusText}] ${text}`, response); + } catch { + return new ApiClientError(`${message}: ${response.status} ${response.statusText}`, response); + } + } +} diff --git a/src/tools/mongodb/tools.ts b/src/tools/mongodb/tools.ts index 69bff401..ac22e095 100644 --- a/src/tools/mongodb/tools.ts +++ b/src/tools/mongodb/tools.ts @@ -41,4 +41,5 @@ export const MongoDbTools = [ RenameCollectionTool, DropDatabaseTool, DropCollectionTool, + ExplainTool, ]; From 688ad56ede1f2e2be5b720bf87921aca51dccb27 Mon Sep 17 00:00:00 2001 From: gagik Date: Fri, 11 Apr 2025 17:03:36 +0200 Subject: [PATCH 9/9] fix: linting --- src/tools/mongodb/metadata/explain.ts | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/tools/mongodb/metadata/explain.ts b/src/tools/mongodb/metadata/explain.ts index 5cea2c96..4a750a1f 100644 --- a/src/tools/mongodb/metadata/explain.ts +++ b/src/tools/mongodb/metadata/explain.ts @@ -43,7 +43,7 @@ export class ExplainTool extends MongoDBToolBase { collection, method: methods, }: ToolArgs): Promise { - const provider = this.ensureConnected(); + const provider = await this.ensureConnected(); const method = methods[0]; if (!method) { @@ -60,7 +60,7 @@ export class ExplainTool extends MongoDBToolBase { case "find": { const { filter, ...rest } = method.arguments; result = await provider - .find(database, collection, filter, { ...rest }) + .find(database, collection, filter as Document, { ...rest }) .explain(ExplainTool.defaultVerbosity); break; } @@ -72,14 +72,12 @@ export class ExplainTool extends MongoDBToolBase { })) as unknown as Document; break; } - default: - throw new Error(`Unsupported method: ${method}`); } return { content: [ { - text: `Here is some information about the winning plan chosen by the query optimizer for running the given \`${method}\` operation in \`${database}.${collection}\`. This information can be used to understand how the query was executed and to optimize the query performance.`, + text: `Here is some information about the winning plan chosen by the query optimizer for running the given \`${method.name}\` operation in \`${database}.${collection}\`. This information can be used to understand how the query was executed and to optimize the query performance.`, type: "text", }, {