Skip to content

Commit 042aa1d

Browse files
fix: add guards against possible memory overflow
Targets the find and aggregate tools and does the following to reduce the risk of memory overflow: 1. Adds a configurable limit that restricts the number of documents fetched per query / aggregation. 2. Adds an iterator that keeps track of the bytes consumed in memory by the retrieved documents and cuts off the iteration when fetching more could overflow. The overflow threshold is the configured maxBytesPerQuery parameter, which defaults to 1 MB.
1 parent 76aa332 commit 042aa1d

File tree

7 files changed

+269
-32
lines changed

7 files changed

+269
-32
lines changed

src/common/config.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ export interface UserConfig extends CliOptions {
161161
loggers: Array<"stderr" | "disk" | "mcp">;
162162
idleTimeoutMs: number;
163163
notificationTimeoutMs: number;
164+
maxDocumentsPerQuery: number;
165+
maxBytesPerQuery: number;
164166
}
165167

166168
export const defaultUserConfig: UserConfig = {
@@ -180,6 +182,8 @@ export const defaultUserConfig: UserConfig = {
180182
idleTimeoutMs: 600000, // 10 minutes
181183
notificationTimeoutMs: 540000, // 9 minutes
182184
httpHeaders: {},
185+
maxDocumentsPerQuery: 50,
186+
maxBytesPerQuery: 1 * 1000 * 1000, // 1 mb
183187
};
184188

185189
export const config = setupUserConfig({

src/helpers/iterateCursor.ts

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import { calculateObjectSize } from "bson";
2+
import type { AggregationCursor, FindCursor } from "mongodb";
3+
4+
/**
 * Guard rail against accidental memory overflow on the MCP server.
 *
 * Documents are pulled from the cursor one at a time with `tryNext()` and
 * buffered for as long as their cumulative size, as reported by
 * `calculateObjectSize`, stays below `maxBytesPerQuery`. Iteration stops when
 * any of the following holds:
 *
 * - the running total plus the size of the largest document buffered so far
 *   reaches the limit; this is a prediction that the next document would not
 *   fit, so it is not fetched at all,
 * - the cursor yields no further document,
 * - the document that was just fetched would bring the running total to or
 *   above the limit; that document is discarded.
 *
 * The cursor is not closed here; that is left to the caller.
 *
 * @param cursor - The find or aggregation cursor to drain.
 * @param maxBytesPerQuery - Byte budget for the buffered documents.
 * @returns The documents that fit within the budget, in cursor order.
 * @throws RangeError when `maxBytesPerQuery` is `NaN`, because every
 * comparison against `NaN` is false and the guard would never trigger.
 */
export async function iterateCursorUntilMaxBytes(
    cursor: FindCursor<unknown> | AggregationCursor<unknown>,
    maxBytesPerQuery: number
): Promise<unknown[]> {
    if (Number.isNaN(maxBytesPerQuery)) {
        throw new RangeError("maxBytesPerQuery must be a number, received NaN");
    }

    let biggestDocSizeSoFar = 0;
    let totalBytes = 0;
    const bufferedDocuments: unknown[] = [];

    while (true) {
        // Predictive check: assume the next document is as large as the
        // largest one seen so far and stop before fetching it if it would
        // not fit.
        if (totalBytes + biggestDocSizeSoFar >= maxBytesPerQuery) {
            break;
        }

        const nextDocument = await cursor.tryNext();
        if (!nextDocument) {
            break;
        }

        // Exact check: the fetched document is dropped when adding it would
        // reach or pass the limit.
        const nextDocumentSize = calculateObjectSize(nextDocument);
        if (totalBytes + nextDocumentSize >= maxBytesPerQuery) {
            break;
        }

        totalBytes += nextDocumentSize;
        biggestDocSizeSoFar = Math.max(biggestDocSizeSoFar, nextDocumentSize);
        bufferedDocuments.push(nextDocument);
    }

    return bufferedDocuments;
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
/** A zero-argument asynchronous operation whose failure can be tolerated. */
type OperationCallback<OperationResult> = () => Promise<OperationResult>;

/**
 * Runs `performOperation` and resolves with its result. If the operation
 * throws or rejects, the error is discarded and `fallback` is returned
 * instead, so this function itself never rejects.
 *
 * @param performOperation - The operation to attempt.
 * @param fallback - The value to resolve with when the operation fails.
 * @returns The operation result on success, otherwise `fallback`.
 */
export async function operationWithFallback<OperationResult, FallbackValue>(
    performOperation: OperationCallback<OperationResult>,
    fallback: FallbackValue
): Promise<OperationResult | FallbackValue> {
    let outcome: OperationResult | FallbackValue = fallback;
    try {
        outcome = await performOperation();
    } catch {
        // Best effort: any failure leaves the fallback in place.
    }
    return outcome;
}

src/tools/mongodb/read/aggregate.ts

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
import { z } from "zod";
2+
import type { AggregationCursor } from "mongodb";
23
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
4+
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
35
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
46
import type { ToolArgs, OperationType } from "../../tool.js";
57
import { formatUntrustedData } from "../../tool.js";
68
import { checkIndexUsage } from "../../../helpers/indexCheck.js";
7-
import { EJSON } from "bson";
9+
import { type Document, EJSON } from "bson";
810
import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
11+
import { iterateCursorUntilMaxBytes } from "../../../helpers/iterateCursor.js";
12+
import { operationWithFallback } from "../../../helpers/operationWithFallback.js";
13+
14+
/**
15+
* A cap for the maxTimeMS used for counting resulting documents of an
16+
* aggregation.
17+
*/
18+
const AGG_COUNT_MAX_TIME_MS_CAP = 60_000;
919

1020
export const AggregateArgs = {
1121
pipeline: z.array(z.object({}).passthrough()).describe("An array of aggregation stages to execute"),
@@ -25,27 +35,43 @@ export class AggregateTool extends MongoDBToolBase {
2535
collection,
2636
pipeline,
2737
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
28-
const provider = await this.ensureConnected();
38+
let aggregationCursor: AggregationCursor | undefined;
39+
try {
40+
const provider = await this.ensureConnected();
2941

30-
this.assertOnlyUsesPermittedStages(pipeline);
42+
this.assertOnlyUsesPermittedStages(pipeline);
3143

32-
// Check if aggregate operation uses an index if enabled
33-
if (this.config.indexCheck) {
34-
await checkIndexUsage(provider, database, collection, "aggregate", async () => {
35-
return provider
36-
.aggregate(database, collection, pipeline, {}, { writeConcern: undefined })
37-
.explain("queryPlanner");
38-
});
39-
}
44+
// Check if aggregate operation uses an index if enabled
45+
if (this.config.indexCheck) {
46+
await checkIndexUsage(provider, database, collection, "aggregate", async () => {
47+
return provider
48+
.aggregate(database, collection, pipeline, {}, { writeConcern: undefined })
49+
.explain("queryPlanner");
50+
});
51+
}
52+
53+
const cappedResultsPipeline = [...pipeline, { $limit: this.config.maxDocumentsPerQuery }];
54+
aggregationCursor = provider
55+
.aggregate(database, collection, cappedResultsPipeline)
56+
.batchSize(this.config.maxDocumentsPerQuery);
4057

41-
const documents = await provider.aggregate(database, collection, pipeline).toArray();
58+
const [totalDocuments, documents] = await Promise.all([
59+
this.countAggregationResultDocuments({ provider, database, collection, pipeline }),
60+
iterateCursorUntilMaxBytes(aggregationCursor, this.config.maxBytesPerQuery),
61+
]);
4262

43-
return {
44-
content: formatUntrustedData(
45-
`The aggregation resulted in ${documents.length} documents.`,
46-
documents.length > 0 ? EJSON.stringify(documents) : undefined
47-
),
48-
};
63+
const messageDescription = `\
64+
The aggregation resulted in ${totalDocuments === undefined ? "indeterminable number of" : totalDocuments} documents. \
65+
Returning ${documents.length} documents while respecting the applied limits. \
66+
Note to LLM: If entire aggregation result is needed then use "export" tool to export the aggregation results.\
67+
`;
68+
69+
return {
70+
content: formatUntrustedData(messageDescription, EJSON.stringify(documents)),
71+
};
72+
} finally {
73+
await aggregationCursor?.close();
74+
}
4975
}
5076

5177
private assertOnlyUsesPermittedStages(pipeline: Record<string, unknown>[]): void {
@@ -62,4 +88,35 @@ export class AggregateTool extends MongoDBToolBase {
6288
}
6389
}
6490
}
91+
92+
/**
 * Counts the documents the given pipeline would produce by running it with
 * a trailing `$count` stage, bounded by `AGG_COUNT_MAX_TIME_MS_CAP`.
 *
 * The count is best effort: if the counting aggregation fails for any
 * reason, or returns something other than a single document with a numeric
 * `totalDocuments` field, `undefined` is returned.
 *
 * @returns The number of resulting documents, `0` when the pipeline
 * produces none, or `undefined` when the count could not be determined.
 */
private async countAggregationResultDocuments({
    provider,
    database,
    collection,
    pipeline,
}: {
    provider: NodeDriverServiceProvider;
    database: string;
    collection: string;
    pipeline: Document[];
}): Promise<number | undefined> {
    const resultsCountAggregation = [...pipeline, { $count: "totalDocuments" }];
    return await operationWithFallback(async (): Promise<number | undefined> => {
        const aggregationResults = await provider
            .aggregate(database, collection, resultsCountAggregation)
            .maxTimeMS(AGG_COUNT_MAX_TIME_MS_CAP)
            .toArray();

        // `$count` emits no document at all when no documents reach it, so
        // an empty result means the pipeline produced zero documents, not
        // that the count is unknown.
        if (aggregationResults.length === 0) {
            return 0;
        }

        const documentWithCount: unknown = aggregationResults.length === 1 ? aggregationResults[0] : undefined;
        const totalDocuments =
            documentWithCount &&
            typeof documentWithCount === "object" &&
            "totalDocuments" in documentWithCount &&
            typeof documentWithCount.totalDocuments === "number"
                ? documentWithCount.totalDocuments
                : undefined;

        return totalDocuments;
    }, undefined);
}
65122
}

src/tools/mongodb/read/find.ts

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,20 @@ import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
33
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
44
import type { ToolArgs, OperationType } from "../../tool.js";
55
import { formatUntrustedData } from "../../tool.js";
6-
import type { SortDirection } from "mongodb";
6+
import type { FindCursor, SortDirection } from "mongodb";
77
import { checkIndexUsage } from "../../../helpers/indexCheck.js";
88
import { EJSON } from "bson";
9+
import { iterateCursorUntilMaxBytes } from "../../../helpers/iterateCursor.js";
10+
import { operationWithFallback } from "../../../helpers/operationWithFallback.js";
11+
12+
/**
13+
* A cap for the maxTimeMS used for FindCursor.countDocuments.
14+
*
15+
* The number is relatively small because we expect the count documents query
16+
* to finish before, or at least by the time, the batch of documents is
17+
* retrieved, so that the count documents query doesn't hold the final response back.
18+
*/
19+
const QUERY_COUNT_MAX_TIME_MS_CAP = 10_000;
920

1021
export const FindArgs = {
1122
filter: z
@@ -45,22 +56,50 @@ export class FindTool extends MongoDBToolBase {
4556
limit,
4657
sort,
4758
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
48-
const provider = await this.ensureConnected();
59+
let findCursor: FindCursor<unknown> | undefined;
60+
try {
61+
const provider = await this.ensureConnected();
62+
63+
// Check if find operation uses an index if enabled
64+
if (this.config.indexCheck) {
65+
await checkIndexUsage(provider, database, collection, "find", async () => {
66+
return provider
67+
.find(database, collection, filter, { projection, limit, sort })
68+
.explain("queryPlanner");
69+
});
70+
}
4971

50-
// Check if find operation uses an index if enabled
51-
if (this.config.indexCheck) {
52-
await checkIndexUsage(provider, database, collection, "find", async () => {
53-
return provider.find(database, collection, filter, { projection, limit, sort }).explain("queryPlanner");
72+
const appliedLimit = Math.min(limit, this.config.maxDocumentsPerQuery);
73+
findCursor = provider.find(database, collection, filter, {
74+
projection,
75+
limit: appliedLimit,
76+
sort,
77+
batchSize: appliedLimit,
5478
});
55-
}
5679

57-
const documents = await provider.find(database, collection, filter, { projection, limit, sort }).toArray();
80+
const [queryResultsCount, documents] = await Promise.all([
81+
operationWithFallback(
82+
() =>
83+
provider.countDocuments(database, collection, filter, {
84+
limit,
85+
maxTimeMS: QUERY_COUNT_MAX_TIME_MS_CAP,
86+
}),
87+
undefined
88+
),
89+
iterateCursorUntilMaxBytes(findCursor, this.config.maxBytesPerQuery),
90+
]);
5891

59-
return {
60-
content: formatUntrustedData(
61-
`Found ${documents.length} documents in the collection "${collection}".`,
62-
documents.length > 0 ? EJSON.stringify(documents) : undefined
63-
),
64-
};
92+
const messageDescription = `\
93+
Query on collection "${collection}" resulted in ${queryResultsCount === undefined ? "indeterminable number of" : queryResultsCount} documents. \
94+
Returning ${documents.length} documents while respecting the applied limits. \
95+
Note to LLM: If entire query result is needed then use "export" tool to export the query results.\
96+
`;
97+
98+
return {
99+
content: formatUntrustedData(messageDescription, EJSON.stringify(documents)),
100+
};
101+
} finally {
102+
await findCursor?.close();
103+
}
65104
}
66105
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import { describe, it, expect, vi } from "vitest";
2+
import type { FindCursor } from "mongodb";
3+
import { calculateObjectSize } from "bson";
4+
import { iterateCursorUntilMaxBytes } from "../../../src/helpers/iterateCursor.js";
5+
6+
// Unit tests for iterateCursorUntilMaxBytes, driven by a minimal mock cursor
// that only implements `tryNext`.
describe("iterateCursorUntilMaxBytes", () => {
    /**
     * Builds a cursor stand-in whose `tryNext` yields the given documents in
     * order and then resolves to `null`.
     */
    function createMockCursor(docs: unknown[]): FindCursor<unknown> {
        let idx = 0;
        return {
            tryNext: vi.fn(() => {
                if (idx < docs.length) {
                    return Promise.resolve(docs[idx++]);
                }
                return Promise.resolve(null);
            }),
        } as unknown as FindCursor<unknown>;
    }

    it("returns all docs if under maxBytesPerQuery", async () => {
        const docs = [{ a: 1 }, { b: 2 }];
        const cursor = createMockCursor(docs);
        const maxBytes = 10000;
        const result = await iterateCursorUntilMaxBytes(cursor, maxBytes);
        expect(result).toEqual(docs);
    });

    it("returns only docs that fit under maxBytesPerQuery", async () => {
        const doc1 = { a: "x".repeat(100) };
        const doc2 = { b: "y".repeat(1000) };
        const docs = [doc1, doc2];
        const cursor = createMockCursor(docs);
        const maxBytes = calculateObjectSize(doc1) + 10;
        const result = await iterateCursorUntilMaxBytes(cursor, maxBytes);
        expect(result).toEqual([doc1]);
    });

    it("returns empty array if maxBytesPerQuery is smaller than even the first doc", async () => {
        const docs = [{ a: "x".repeat(100) }];
        const cursor = createMockCursor(docs);
        const result = await iterateCursorUntilMaxBytes(cursor, 10);
        expect(result).toEqual([]);
    });

    it("handles empty cursor", async () => {
        const cursor = createMockCursor([]);
        const result = await iterateCursorUntilMaxBytes(cursor, 1000);
        expect(result).toEqual([]);
    });

    it("does not include a doc that would overflow the max bytes allowed", async () => {
        const doc1 = { a: "x".repeat(10) };
        const doc2 = { b: "y".repeat(1000) };
        const docs = [doc1, doc2];
        const cursor = createMockCursor(docs);
        // Set maxBytes one byte short of both docs combined: doc1 is small
        // enough that doc2 is still fetched, but adding doc2 would reach the
        // limit, so it must be dropped.
        const maxBytes = calculateObjectSize(doc1) + calculateObjectSize(doc2) - 1;
        const result = await iterateCursorUntilMaxBytes(cursor, maxBytes);
        // Should only include doc1, not doc2
        expect(result).toEqual([doc1]);
    });
});
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import { describe, it, expect, vi } from "vitest";
2+
import { operationWithFallback } from "../../../src/helpers/operationWithFallback.js";
3+
4+
// Unit tests for operationWithFallback: the success path passes the result
// through, the failure path substitutes the fallback.
describe("operationWithFallback", () => {
    it("returns operation result when operation succeeds", async () => {
        const operation = vi.fn(() => Promise.resolve("success"));

        const outcome = await operationWithFallback(operation, "fallback");

        expect(outcome).toBe("success");
        expect(operation).toHaveBeenCalledTimes(1);
    });

    it("returns fallback value when operation throws an error", async () => {
        const operation = vi.fn(() => Promise.reject(new Error("Operation failed")));

        const outcome = await operationWithFallback(operation, "fallback");

        expect(outcome).toBe("fallback");
        expect(operation).toHaveBeenCalledTimes(1);
    });
});

0 commit comments

Comments
 (0)