11import { z } from "zod" ;
2+ import type { AggregationCursor } from "mongodb" ;
23import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js" ;
4+ import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver" ;
35import { DbOperationArgs , MongoDBToolBase } from "../mongodbTool.js" ;
46import type { ToolArgs , OperationType } from "../../tool.js" ;
57import { formatUntrustedData } from "../../tool.js" ;
68import { checkIndexUsage } from "../../../helpers/indexCheck.js" ;
7- import { EJSON } from "bson" ;
9+ import { type Document , EJSON } from "bson" ;
810import { ErrorCodes , MongoDBError } from "../../../common/errors.js" ;
11+ import { iterateCursorUntilMaxBytes } from "../../../helpers/iterateCursor.js" ;
12+ import { operationWithFallback } from "../../../helpers/operationWithFallback.js" ;
13+
14+ /**
15+ * A cap for the maxTimeMS used for counting resulting documents of an
16+ * aggregation.
17+ */
18+ const AGG_COUNT_MAX_TIME_MS_CAP = 60_000 ;
919
1020export const AggregateArgs = {
1121 pipeline : z . array ( z . object ( { } ) . passthrough ( ) ) . describe ( "An array of aggregation stages to execute" ) ,
@@ -25,27 +35,43 @@ export class AggregateTool extends MongoDBToolBase {
2535 collection,
2636 pipeline,
2737 } : ToolArgs < typeof this . argsShape > ) : Promise < CallToolResult > {
28- const provider = await this . ensureConnected ( ) ;
38+ let aggregationCursor : AggregationCursor | undefined ;
39+ try {
40+ const provider = await this . ensureConnected ( ) ;
2941
30- this . assertOnlyUsesPermittedStages ( pipeline ) ;
42+ this . assertOnlyUsesPermittedStages ( pipeline ) ;
3143
32- // Check if aggregate operation uses an index if enabled
33- if ( this . config . indexCheck ) {
34- await checkIndexUsage ( provider , database , collection , "aggregate" , async ( ) => {
35- return provider
36- . aggregate ( database , collection , pipeline , { } , { writeConcern : undefined } )
37- . explain ( "queryPlanner" ) ;
38- } ) ;
39- }
44+ // Check if aggregate operation uses an index if enabled
45+ if ( this . config . indexCheck ) {
46+ await checkIndexUsage ( provider , database , collection , "aggregate" , async ( ) => {
47+ return provider
48+ . aggregate ( database , collection , pipeline , { } , { writeConcern : undefined } )
49+ . explain ( "queryPlanner" ) ;
50+ } ) ;
51+ }
52+
53+ const cappedResultsPipeline = [ ...pipeline , { $limit : this . config . maxDocumentsPerQuery } ] ;
54+ aggregationCursor = provider
55+ . aggregate ( database , collection , cappedResultsPipeline )
56+ . batchSize ( this . config . maxDocumentsPerQuery ) ;
4057
41- const documents = await provider . aggregate ( database , collection , pipeline ) . toArray ( ) ;
58+ const [ totalDocuments , documents ] = await Promise . all ( [
59+ this . countAggregationResultDocuments ( { provider, database, collection, pipeline } ) ,
60+ iterateCursorUntilMaxBytes ( aggregationCursor , this . config . maxBytesPerQuery ) ,
61+ ] ) ;
4262
43- return {
44- content : formatUntrustedData (
45- `The aggregation resulted in ${ documents . length } documents.` ,
46- documents . length > 0 ? EJSON . stringify ( documents ) : undefined
47- ) ,
48- } ;
63+ const messageDescription = `\
64+ The aggregation resulted in ${ totalDocuments === undefined ? "indeterminable number of" : totalDocuments } documents. \
65+ Returning ${ documents . length } documents while respecting the applied limits. \
66+ Note to LLM: If entire aggregation result is needed then use "export" tool to export the aggregation results.\
67+ ` ;
68+
69+ return {
70+ content : formatUntrustedData ( messageDescription , EJSON . stringify ( documents ) ) ,
71+ } ;
72+ } finally {
73+ await aggregationCursor ?. close ( ) ;
74+ }
4975 }
5076
5177 private assertOnlyUsesPermittedStages ( pipeline : Record < string , unknown > [ ] ) : void {
@@ -62,4 +88,35 @@ export class AggregateTool extends MongoDBToolBase {
6288 }
6389 }
6490 }
91+
92+ private async countAggregationResultDocuments ( {
93+ provider,
94+ database,
95+ collection,
96+ pipeline,
97+ } : {
98+ provider : NodeDriverServiceProvider ;
99+ database : string ;
100+ collection : string ;
101+ pipeline : Document [ ] ;
102+ } ) : Promise < number | undefined > {
103+ const resultsCountAggregation = [ ...pipeline , { $count : "totalDocuments" } ] ;
104+ return await operationWithFallback ( async ( ) : Promise < number | undefined > => {
105+ const aggregationResults = await provider
106+ . aggregate ( database , collection , resultsCountAggregation )
107+ . maxTimeMS ( AGG_COUNT_MAX_TIME_MS_CAP )
108+ . toArray ( ) ;
109+
110+ const documentWithCount : unknown = aggregationResults . length === 1 ? aggregationResults [ 0 ] : undefined ;
111+ const totalDocuments =
112+ documentWithCount &&
113+ typeof documentWithCount === "object" &&
114+ "totalDocuments" in documentWithCount &&
115+ typeof documentWithCount . totalDocuments === "number"
116+ ? documentWithCount . totalDocuments
117+ : undefined ;
118+
119+ return totalDocuments ;
120+ } , undefined ) ;
121+ }
65122}
0 commit comments