diff --git a/src/common/exportsManager.ts b/src/common/exportsManager.ts index 384203092..9235e8b25 100644 --- a/src/common/exportsManager.ts +++ b/src/common/exportsManager.ts @@ -3,7 +3,7 @@ import path from "path"; import fs from "fs/promises"; import EventEmitter from "events"; import { createWriteStream } from "fs"; -import { FindCursor } from "mongodb"; +import { AggregationCursor, FindCursor } from "mongodb"; import { EJSON, EJSONOptions, ObjectId } from "bson"; import { Transform } from "stream"; import { pipeline } from "stream/promises"; @@ -154,7 +154,7 @@ export class ExportsManager extends EventEmitter { exportTitle, jsonExportFormat, }: { - input: FindCursor; + input: FindCursor | AggregationCursor; exportName: string; exportTitle: string; jsonExportFormat: JSONExportFormat; @@ -194,7 +194,7 @@ export class ExportsManager extends EventEmitter { jsonExportFormat, inProgressExport, }: { - input: FindCursor; + input: FindCursor | AggregationCursor; jsonExportFormat: JSONExportFormat; inProgressExport: InProgressExport; }): Promise { diff --git a/src/tools/mongodb/read/export.ts b/src/tools/mongodb/read/export.ts index 9eaacba2f..2a6097c8e 100644 --- a/src/tools/mongodb/read/export.ts +++ b/src/tools/mongodb/read/export.ts @@ -1,19 +1,42 @@ import z from "zod"; import { ObjectId } from "bson"; +import { AggregationCursor, FindCursor } from "mongodb"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { OperationType, ToolArgs } from "../../tool.js"; import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { FindArgs } from "./find.js"; import { jsonExportFormat } from "../../../common/exportsManager.js"; +import { AggregateArgs } from "./aggregate.js"; export class ExportTool extends MongoDBToolBase { public name = "export"; protected description = "Export a collection data or query results in the specified EJSON format."; protected argsShape = { - exportTitle: z.string().describe("A short description to uniquely identify the export."), ...DbOperationArgs, - ...FindArgs, - limit: z.number().optional().describe("The maximum number of documents to return"), + exportTitle: z.string().describe("A short description to uniquely identify the export."), + exportTarget: z + .array( + z.discriminatedUnion("name", [ + z.object({ + name: z + .literal("find") + .describe("The literal name 'find' to represent a find cursor as target."), + arguments: z + .object({ + ...FindArgs, + limit: FindArgs.limit.removeDefault(), + }) + .describe("The arguments for 'find' operation."), + }), + z.object({ + name: z + .literal("aggregate") + .describe("The literal name 'aggregate' to represent an aggregation cursor as target."), + arguments: z.object(AggregateArgs).describe("The arguments for 'aggregate' operation."), + }), + ]) + ) + .describe("The export target along with its arguments."), jsonExportFormat: jsonExportFormat .default("relaxed") .describe( @@ -30,24 +53,38 @@ export class ExportTool extends MongoDBToolBase { database, collection, jsonExportFormat, - filter, - projection, - sort, - limit, exportTitle, + exportTarget: target, }: ToolArgs): Promise { const provider = await this.ensureConnected(); - const findCursor = provider.find(database, collection, filter ?? {}, { - projection, - sort, - limit, - promoteValues: false, - bsonRegExp: true, - }); + const exportTarget = target[0]; + if (!exportTarget) { + throw new Error("Export target not provided. Expected one of the following: `aggregate`, `find`"); + } + + let cursor: FindCursor | AggregationCursor; + if (exportTarget.name === "find") { + const { filter, projection, sort, limit } = exportTarget.arguments; + cursor = provider.find(database, collection, filter ?? {}, { + projection, + sort, + limit, + promoteValues: false, + bsonRegExp: true, + }); + } else { + const { pipeline } = exportTarget.arguments; + cursor = provider.aggregate(database, collection, pipeline, { + promoteValues: false, + bsonRegExp: true, + allowDiskUse: true, + }); + } + const exportName = `${database}.${collection}.${new ObjectId().toString()}.json`; const { exportURI, exportPath } = await this.session.exportsManager.createJSONExport({ - input: findCursor, + input: cursor, exportName, exportTitle: exportTitle || diff --git a/tests/accuracy/export.test.ts b/tests/accuracy/export.test.ts index 9e1f0cff5..5b2624171 100644 --- a/tests/accuracy/export.test.ts +++ b/tests/accuracy/export.test.ts @@ -10,8 +10,13 @@ describeAccuracyTests([ parameters: { database: "mflix", collection: "movies", - filter: Matcher.emptyObjectOrUndefined, - limit: Matcher.undefined, + exportTitle: Matcher.string(), + exportTarget: [ + { + name: "find", + arguments: {}, + }, + ], }, }, ], @@ -24,9 +29,17 @@ describeAccuracyTests([ parameters: { database: "mflix", collection: "movies", - filter: { - runtime: { $lt: 100 }, - }, + exportTitle: Matcher.string(), + exportTarget: [ + { + name: "find", + arguments: { + filter: { + runtime: { $lt: 100 }, + }, + }, + }, + ], }, }, ], @@ -39,14 +52,22 @@ describeAccuracyTests([ parameters: { database: "mflix", collection: "movies", - projection: { - title: 1, - _id: Matcher.anyOf( - Matcher.undefined, - Matcher.number((value) => value === 0) - ), - }, - filter: Matcher.emptyObjectOrUndefined, + exportTitle: Matcher.string(), + exportTarget: [ + { + name: "find", + arguments: { + projection: { + title: 1, + _id: Matcher.anyOf( + Matcher.undefined, + Matcher.number((value) => value === 0) + ), + }, + filter: Matcher.emptyObjectOrUndefined, + }, + }, + ], }, }, ], @@ -59,9 +80,47 @@ describeAccuracyTests([ parameters: { database: "mflix", collection: "movies", - filter: { genres: "Horror" }, - sort: { runtime: 1 }, - limit: 2, + exportTitle: Matcher.string(), + exportTarget: [ + { + name: "find", + arguments: { + filter: { genres: "Horror" }, + sort: { runtime: 1 }, + limit: 2, + }, + }, + ], + }, + }, + ], + }, + { + prompt: "Export an aggregation that groups all movie titles by the field release_year from mflix.movies", + expectedToolCalls: [ + { + toolName: "export", + parameters: { + database: "mflix", + collection: "movies", + exportTitle: Matcher.string(), + exportTarget: [ + { + name: "aggregate", + arguments: { + pipeline: [ + { + $group: { + _id: "$release_year", + titles: { + $push: "$title", + }, + }, + }, + ], + }, + }, + ], }, }, ], diff --git a/tests/integration/resources/exportedData.test.ts b/tests/integration/resources/exportedData.test.ts index 94710d87f..3112fe878 100644 --- a/tests/integration/resources/exportedData.test.ts +++ b/tests/integration/resources/exportedData.test.ts @@ -65,7 +65,12 @@ describeWithMongoDB( await integration.connectMcpClient(); const exportResponse = await integration.mcpClient().callTool({ name: "export", - arguments: { database: "db", collection: "coll", exportTitle: "Export for db.coll" }, + arguments: { + database: "db", + collection: "coll", + exportTitle: "Export for db.coll", + exportTarget: [{ name: "find", arguments: {} }], + }, }); const exportedResourceURI = (exportResponse as CallToolResult).content.find( @@ -99,7 +104,12 @@ describeWithMongoDB( await integration.connectMcpClient(); const exportResponse = await integration.mcpClient().callTool({ name: "export", - arguments: { database: "db", collection: "coll", exportTitle: "Export for db.coll" }, + arguments: { + database: "db", + collection: "coll", + exportTitle: "Export for db.coll", + exportTarget: [{ name: "find", arguments: {} }], + }, }); const content = exportResponse.content as CallToolResult["content"]; const exportURI = contentWithResourceURILink(content)?.uri as string; @@ -122,7 +132,12 @@ describeWithMongoDB( await integration.connectMcpClient(); const exportResponse = await integration.mcpClient().callTool({ name: "export", - arguments: { database: "big", collection: "coll", exportTitle: "Export for big.coll" }, + arguments: { + database: "big", + collection: "coll", + exportTitle: "Export for big.coll", + exportTarget: [{ name: "find", arguments: {} }], + }, }); const content = exportResponse.content as CallToolResult["content"]; const exportURI = contentWithResourceURILink(content)?.uri as string; diff --git a/tests/integration/tools/mongodb/read/export.test.ts b/tests/integration/tools/mongodb/read/export.test.ts index 343f3ef45..f02460f01 100644 --- a/tests/integration/tools/mongodb/read/export.test.ts +++ b/tests/integration/tools/mongodb/read/export.test.ts @@ -62,12 +62,6 @@ describeWithMongoDB( type: "string", required: true, }, - { - name: "filter", - description: "The query filter, matching the syntax of the query argument of db.collection.find()", - type: "object", - required: false, - }, { name: "jsonExportFormat", description: [ @@ -79,24 +73,10 @@ describeWithMongoDB( required: false, }, { - name: "limit", - description: "The maximum number of documents to return", - type: "number", - required: false, - }, - { - name: "projection", - description: - "The projection, matching the syntax of the projection argument of db.collection.find()", - type: "object", - required: false, - }, - { - name: "sort", - description: - "A document, describing the sort order, matching the syntax of the sort argument of cursor.sort(). The keys of the object are the fields to sort on, while the values are the sort directions (1 for ascending, -1 for descending).", - type: "object", - required: false, + name: "exportTarget", + type: "array", + description: "The export target along with its arguments.", + required: true, }, ] ); @@ -126,6 +106,14 @@ describeWithMongoDB( database: "non-existent", collection: "foos", exportTitle: "Export for non-existent.foos", + exportTarget: [ + { + name: "find", + arguments: { + filter: {}, + }, + }, + ], }, }); const content = response.content as CallToolResult["content"]; @@ -165,6 +153,14 @@ describeWithMongoDB( database: integration.randomDbName(), collection: "foo", exportTitle: `Export for ${integration.randomDbName()}.foo`, + exportTarget: [ + { + name: "find", + arguments: { + filter: {}, + }, + }, + ], }, }); const content = response.content as CallToolResult["content"]; @@ -192,8 +188,15 @@ describeWithMongoDB( arguments: { database: integration.randomDbName(), collection: "foo", - filter: { name: "foo" }, exportTitle: `Export for ${integration.randomDbName()}.foo`, + exportTarget: [ + { + name: "find", + arguments: { + filter: { name: "foo" }, + }, + }, + ], }, }); const content = response.content as CallToolResult["content"]; @@ -220,8 +223,16 @@ describeWithMongoDB( arguments: { database: integration.randomDbName(), collection: "foo", - limit: 1, exportTitle: `Export for ${integration.randomDbName()}.foo`, + exportTarget: [ + { + name: "find", + arguments: { + filter: {}, + limit: 1, + }, + }, + ], }, }); const content = response.content as CallToolResult["content"]; @@ -248,9 +259,17 @@ describeWithMongoDB( arguments: { database: integration.randomDbName(), collection: "foo", - limit: 1, - sort: { longNumber: 1 }, exportTitle: `Export for ${integration.randomDbName()}.foo`, + exportTarget: [ + { + name: "find", + arguments: { + filter: {}, + limit: 1, + sort: { longNumber: 1 }, + }, + }, + ], }, }); const content = response.content as CallToolResult["content"]; @@ -277,9 +296,17 @@ describeWithMongoDB( arguments: { database: integration.randomDbName(), collection: "foo", - limit: 1, - projection: { _id: 0, name: 1 }, exportTitle: `Export for ${integration.randomDbName()}.foo`, + exportTarget: [ + { + name: "find", + arguments: { + filter: {}, + limit: 1, + projection: { _id: 0, name: 1 }, + }, + }, + ], }, }); const content = response.content as CallToolResult["content"]; @@ -309,10 +336,18 @@ describeWithMongoDB( arguments: { database: integration.randomDbName(), collection: "foo", - limit: 1, - projection: { _id: 0 }, jsonExportFormat: "relaxed", exportTitle: `Export for ${integration.randomDbName()}.foo`, + exportTarget: [ + { + name: "find", + arguments: { + filter: {}, + limit: 1, + projection: { _id: 0 }, + }, + }, + ], }, }); const content = response.content as CallToolResult["content"]; @@ -343,10 +378,18 @@ describeWithMongoDB( arguments: { database: integration.randomDbName(), collection: "foo", - limit: 1, - projection: { _id: 0 }, jsonExportFormat: "canonical", exportTitle: `Export for ${integration.randomDbName()}.foo`, + exportTarget: [ + { + name: "find", + arguments: { + filter: {}, + limit: 1, + projection: { _id: 0 }, + }, + }, + ], }, }); const content = response.content as CallToolResult["content"]; @@ -371,6 +414,48 @@ describeWithMongoDB( }, ]); }); + + it("should allow exporting an aggregation", async () => { + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "export", + arguments: { + database: integration.randomDbName(), + collection: "foo", + exportTitle: `Export for ${integration.randomDbName()}.foo`, + exportTarget: [ + { + name: "aggregate", + arguments: { + pipeline: [ + { + $match: {}, + }, + { + $limit: 1, + }, + ], + }, + }, + ], + }, + }); + const content = response.content as CallToolResult["content"]; + const exportURI = contentWithResourceURILink(content)?.uri as string; + await resourceChangedNotification(integration.mcpClient(), exportURI); + + const localPathPart = contentWithExportPath(content); + expect(localPathPart).toBeDefined(); + const [, localPath] = /"(.*)"/.exec(String(localPathPart?.text)) ?? []; + expect(localPath).toBeDefined(); + + const exportedContent = JSON.parse(await fs.readFile(localPath as string, "utf8")) as Record< + string, + unknown + >[]; + expect(exportedContent).toHaveLength(1); + expect(exportedContent[0]?.name).toEqual("foo"); + }); }); }, () => userConfig