diff --git a/src/tools/mongodb/metadata/explain.ts b/src/tools/mongodb/metadata/explain.ts new file mode 100644 index 00000000..4a750a1f --- /dev/null +++ b/src/tools/mongodb/metadata/explain.ts @@ -0,0 +1,90 @@ +import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; +import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; +import { ToolArgs } from "../../tool.js"; +import { z } from "zod"; +import { ExplainVerbosity, Document } from "mongodb"; +import { AggregateArgs } from "../read/aggregate.js"; +import { FindArgs } from "../read/find.js"; +import { CountArgs } from "../read/count.js"; + +export class ExplainTool extends MongoDBToolBase { + protected name = "explain"; + protected description = + "Returns statistics describing the execution of the winning plan chosen by the query optimizer for the evaluated method"; + + protected argsShape = { + ...DbOperationArgs, + method: z + .array( + z.union([ + z.object({ + name: z.literal("aggregate"), + arguments: z.object(AggregateArgs), + }), + z.object({ + name: z.literal("find"), + arguments: z.object(FindArgs), + }), + z.object({ + name: z.literal("count"), + arguments: z.object(CountArgs), + }), + ]) + ) + .describe("The method and its arguments to run"), + }; + + protected operationType: DbOperationType = "metadata"; + + static readonly defaultVerbosity = ExplainVerbosity.queryPlanner; + + protected async execute({ + database, + collection, + method: methods, + }: ToolArgs): Promise { + const provider = await this.ensureConnected(); + const method = methods[0]; + + if (!method) { + throw new Error("No method provided"); + } + + let result: Document; + switch (method.name) { + case "aggregate": { + const { pipeline } = method.arguments; + result = await provider.aggregate(database, collection, pipeline).explain(ExplainTool.defaultVerbosity); + break; + } + case "find": { + const { filter, ...rest } = method.arguments; + result = await provider + .find(database, collection, filter as Document, { ...rest }) + .explain(ExplainTool.defaultVerbosity); + break; + } + case "count": { + const { query } = method.arguments; + // This helper doesn't have explain() command but does have the argument explain + result = (await provider.count(database, collection, query, { + explain: ExplainTool.defaultVerbosity, + })) as unknown as Document; + break; + } + } + + return { + content: [ + { + text: `Here is some information about the winning plan chosen by the query optimizer for running the given \`${method.name}\` operation in \`${database}.${collection}\`. This information can be used to understand how the query was executed and to optimize the query performance.`, + type: "text", + }, + { + text: JSON.stringify(result), + type: "text", + }, + ], + }; + } +} diff --git a/src/tools/mongodb/read/aggregate.ts b/src/tools/mongodb/read/aggregate.ts index 66bc1edb..abdbf70e 100644 --- a/src/tools/mongodb/read/aggregate.ts +++ b/src/tools/mongodb/read/aggregate.ts @@ -1,16 +1,19 @@ import { z } from "zod"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; -import { DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; +import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs } from "../../tool.js"; +export const AggregateArgs = { + pipeline: z.array(z.object({}).passthrough()).describe("An array of aggregation stages to execute"), + limit: z.number().optional().default(10).describe("The maximum number of documents to return"), +}; + export class AggregateTool extends MongoDBToolBase { protected name = "aggregate"; protected description = "Run an aggregation against a MongoDB collection"; protected argsShape = { - collection: z.string().describe("Collection name"), - database: z.string().describe("Database name"), - pipeline: z.array(z.object({}).passthrough()).describe("An array of aggregation stages to execute"), - limit: z.number().optional().default(10).describe("The maximum number of documents to return"), + ...DbOperationArgs, + ...AggregateArgs, }; protected operationType: DbOperationType = "read"; diff --git a/src/tools/mongodb/read/count.ts b/src/tools/mongodb/read/count.ts index 8c5f446d..63d052bb 100644 --- a/src/tools/mongodb/read/count.ts +++ b/src/tools/mongodb/read/count.ts @@ -3,18 +3,22 @@ import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbToo import { ToolArgs } from "../../tool.js"; import { z } from "zod"; +export const CountArgs = { + query: z + .object({}) + .passthrough() + .optional() + .describe( + "The query filter to count documents. Matches the syntax of the filter argument of db.collection.count()" + ), +}; + export class CountTool extends MongoDBToolBase { protected name = "count"; protected description = "Gets the number of documents in a MongoDB collection"; protected argsShape = { ...DbOperationArgs, - query: z - .object({}) - .passthrough() - .optional() - .describe( - "The query filter to count documents. Matches the syntax of the filter argument of db.collection.count()" - ), + ...CountArgs, }; protected operationType: DbOperationType = "metadata"; diff --git a/src/tools/mongodb/read/find.ts b/src/tools/mongodb/read/find.ts index 54edce8e..9893891a 100644 --- a/src/tools/mongodb/read/find.ts +++ b/src/tools/mongodb/read/find.ts @@ -1,32 +1,33 @@ import { z } from "zod"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; -import { DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; +import { DbOperationArgs, DbOperationType, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs } from "../../tool.js"; import { SortDirection } from "mongodb"; +export const FindArgs = { + filter: z + .object({}) + .passthrough() + .optional() + .describe("The query filter, matching the syntax of the query argument of db.collection.find()"), + projection: z + .object({}) + .passthrough() + .optional() + .describe("The projection, matching the syntax of the projection argument of db.collection.find()"), + limit: z.number().optional().default(10).describe("The maximum number of documents to return"), + sort: z + .record(z.string(), z.custom()) + .optional() + .describe("A document, describing the sort order, matching the syntax of the sort argument of cursor.sort()"), +}; + export class FindTool extends MongoDBToolBase { protected name = "find"; protected description = "Run a find query against a MongoDB collection"; protected argsShape = { - collection: z.string().describe("Collection name"), - database: z.string().describe("Database name"), - filter: z - .object({}) - .passthrough() - .optional() - .describe("The query filter, matching the syntax of the query argument of db.collection.find()"), - projection: z - .object({}) - .passthrough() - .optional() - .describe("The projection, matching the syntax of the projection argument of db.collection.find()"), - limit: z.number().optional().default(10).describe("The maximum number of documents to return"), - sort: z - .record(z.string(), z.custom()) - .optional() - .describe( - "A document, describing the sort order, matching the syntax of the sort argument of cursor.sort()" - ), + ...DbOperationArgs, + ...FindArgs, }; protected operationType: DbOperationType = "read"; diff --git a/src/tools/mongodb/tools.ts b/src/tools/mongodb/tools.ts index 0f89335d..ac22e095 100644 --- a/src/tools/mongodb/tools.ts +++ b/src/tools/mongodb/tools.ts @@ -18,6 +18,7 @@ import { UpdateManyTool } from "./update/updateMany.js"; import { RenameCollectionTool } from "./update/renameCollection.js"; import { DropDatabaseTool } from "./delete/dropDatabase.js"; import { DropCollectionTool } from "./delete/dropCollection.js"; +import { ExplainTool } from "./metadata/explain.js"; export const MongoDbTools = [ ConnectTool, @@ -40,4 +41,5 @@ export const MongoDbTools = [ RenameCollectionTool, DropDatabaseTool, DropCollectionTool, + ExplainTool, ];