1
1
import { z } from "zod" ;
2
+ import type { AggregationCursor } from "mongodb" ;
2
3
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js" ;
4
+ import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver" ;
3
5
import { DbOperationArgs , MongoDBToolBase } from "../mongodbTool.js" ;
4
6
import type { ToolArgs , OperationType } from "../../tool.js" ;
5
7
import { formatUntrustedData } from "../../tool.js" ;
6
8
import { checkIndexUsage } from "../../../helpers/indexCheck.js" ;
7
- import { EJSON } from "bson" ;
9
+ import { type Document , EJSON } from "bson" ;
8
10
import { ErrorCodes , MongoDBError } from "../../../common/errors.js" ;
11
+ import { iterateCursorUntilMaxBytes } from "../../../helpers/iterateCursor.js" ;
12
+ import { operationWithFallback } from "../../../helpers/operationWithFallback.js" ;
13
+
14
/**
 * Upper bound, in milliseconds, applied as `maxTimeMS` to the aggregation
 * that counts the documents produced by a pipeline.
 */
const AGG_COUNT_MAX_TIME_MS_CAP = 60 * 1_000;
9
19
10
20
export const AggregateArgs = {
11
21
pipeline : z . array ( z . object ( { } ) . passthrough ( ) ) . describe ( "An array of aggregation stages to execute" ) ,
@@ -25,27 +35,43 @@ export class AggregateTool extends MongoDBToolBase {
25
35
collection,
26
36
pipeline,
27
37
} : ToolArgs < typeof this . argsShape > ) : Promise < CallToolResult > {
28
- const provider = await this . ensureConnected ( ) ;
38
+ let aggregationCursor : AggregationCursor | undefined ;
39
+ try {
40
+ const provider = await this . ensureConnected ( ) ;
29
41
30
- this . assertOnlyUsesPermittedStages ( pipeline ) ;
42
+ this . assertOnlyUsesPermittedStages ( pipeline ) ;
31
43
32
- // Check if aggregate operation uses an index if enabled
33
- if ( this . config . indexCheck ) {
34
- await checkIndexUsage ( provider , database , collection , "aggregate" , async ( ) => {
35
- return provider
36
- . aggregate ( database , collection , pipeline , { } , { writeConcern : undefined } )
37
- . explain ( "queryPlanner" ) ;
38
- } ) ;
39
- }
44
+ // Check if aggregate operation uses an index if enabled
45
+ if ( this . config . indexCheck ) {
46
+ await checkIndexUsage ( provider , database , collection , "aggregate" , async ( ) => {
47
+ return provider
48
+ . aggregate ( database , collection , pipeline , { } , { writeConcern : undefined } )
49
+ . explain ( "queryPlanner" ) ;
50
+ } ) ;
51
+ }
52
+
53
+ const cappedResultsPipeline = [ ...pipeline , { $limit : this . config . maxDocumentsPerQuery } ] ;
54
+ aggregationCursor = provider
55
+ . aggregate ( database , collection , cappedResultsPipeline )
56
+ . batchSize ( this . config . maxDocumentsPerQuery ) ;
40
57
41
- const documents = await provider . aggregate ( database , collection , pipeline ) . toArray ( ) ;
58
+ const [ totalDocuments , documents ] = await Promise . all ( [
59
+ this . countAggregationResultDocuments ( { provider, database, collection, pipeline } ) ,
60
+ iterateCursorUntilMaxBytes ( aggregationCursor , this . config . maxBytesPerQuery ) ,
61
+ ] ) ;
42
62
43
- return {
44
- content : formatUntrustedData (
45
- `The aggregation resulted in ${ documents . length } documents.` ,
46
- documents . length > 0 ? EJSON . stringify ( documents ) : undefined
47
- ) ,
48
- } ;
63
+ const messageDescription = `\
64
+ The aggregation resulted in ${ totalDocuments === undefined ? "indeterminable number of" : totalDocuments } documents. \
65
+ Returning ${ documents . length } documents while respecting the applied limits. \
66
+ Note to LLM: If entire aggregation result is needed then use "export" tool to export the aggregation results.\
67
+ ` ;
68
+
69
+ return {
70
+ content : formatUntrustedData ( messageDescription , EJSON . stringify ( documents ) ) ,
71
+ } ;
72
+ } finally {
73
+ await aggregationCursor ?. close ( ) ;
74
+ }
49
75
}
50
76
51
77
private assertOnlyUsesPermittedStages ( pipeline : Record < string , unknown > [ ] ) : void {
@@ -62,4 +88,35 @@ export class AggregateTool extends MongoDBToolBase {
62
88
}
63
89
}
64
90
}
91
+
92
/**
 * Counts the documents produced by `pipeline` by running a copy of it with a
 * trailing `$count` stage, bounded by `AGG_COUNT_MAX_TIME_MS_CAP`.
 *
 * The caller's pipeline array is copied, not mutated.
 *
 * @param provider - Service provider used to run the counting aggregation.
 * @param database - Name of the database to aggregate against.
 * @param collection - Name of the collection to aggregate against.
 * @param pipeline - The aggregation stages whose output should be counted.
 * @returns The number of resulting documents, `0` when the pipeline produces
 * none, or `undefined` when the count cannot be determined (an unexpected
 * result shape, or the fallback value handed to `operationWithFallback` —
 * presumably used when the operation fails or times out; confirm against
 * that helper).
 */
private async countAggregationResultDocuments({
    provider,
    database,
    collection,
    pipeline,
}: {
    provider: NodeDriverServiceProvider;
    database: string;
    collection: string;
    pipeline: Document[];
}): Promise<number | undefined> {
    const resultsCountAggregation = [...pipeline, { $count: "totalDocuments" }];
    return await operationWithFallback(async (): Promise<number | undefined> => {
        const aggregationResults = await provider
            .aggregate(database, collection, resultsCountAggregation)
            .maxTimeMS(AGG_COUNT_MAX_TIME_MS_CAP)
            .toArray();

        // `$count` emits no document at all when its input is empty, so an
        // empty result set means "zero documents", not "unknown".
        if (aggregationResults.length === 0) {
            return 0;
        }

        // Anything other than exactly one `{ totalDocuments: <number> }`
        // document is treated as an indeterminable count.
        const documentWithCount: unknown = aggregationResults.length === 1 ? aggregationResults[0] : undefined;
        const totalDocuments =
            documentWithCount &&
            typeof documentWithCount === "object" &&
            "totalDocuments" in documentWithCount &&
            typeof documentWithCount.totalDocuments === "number"
                ? documentWithCount.totalDocuments
                : undefined;

        return totalDocuments;
    }, undefined);
}
65
122
}
0 commit comments