Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4e7d98c
fix: add guards against possible memory overflow
himanshusinghs Sep 9, 2025
250299b
chore: fix existing tests
himanshusinghs Sep 9, 2025
f1be251
chore: tests for the new behavior
himanshusinghs Sep 10, 2025
eff03a8
chore: add missing constants files
himanshusinghs Sep 10, 2025
7d670e8
Apply suggestion from @Copilot
himanshusinghs Sep 10, 2025
8e8c3aa
chore: minor typo
himanshusinghs Sep 10, 2025
6bd5638
fix: removes default limit from find tool schema
himanshusinghs Sep 10, 2025
937908b
chore: add an accuracy test for find tool
himanshusinghs Sep 10, 2025
9d9b9f8
chore: PR feedback
himanshusinghs Sep 10, 2025
13d8408
chore: abort cursor iteration on request timeouts
himanshusinghs Sep 10, 2025
f09b4f4
chore: use correct arg in agg tool
himanshusinghs Sep 10, 2025
7354562
chore: export tool matchers
himanshusinghs Sep 10, 2025
819ed01
chore: accuracy test fixes
himanshusinghs Sep 10, 2025
21f1d3e
Merge remote-tracking branch 'origin/main' into fix/MCP-21-avoid-memo…
himanshusinghs Sep 18, 2025
25e0367
chore: PR feedback about generous config defaults
himanshusinghs Sep 19, 2025
67d3ea8
Merge remote-tracking branch 'origin/main' into fix/MCP-21-avoid-memo…
himanshusinghs Sep 19, 2025
8601c05
chore: fix tests after merge
himanshusinghs Sep 19, 2025
955b7d8
chore: account for cursor close errors
himanshusinghs Sep 19, 2025
bca4bbe
chore: remove unnecessary call
himanshusinghs Sep 19, 2025
811474e
chore: revert export changes
himanshusinghs Sep 19, 2025
e3a87b3
chore: remove eager prediction of overflow
himanshusinghs Sep 19, 2025
e1c95bd
chore: initialise cursor variables
himanshusinghs Sep 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/common/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import levenshtein from "ts-levenshtein";

// From: https://github.com/mongodb-js/mongosh/blob/main/packages/cli-repl/src/arg-parser.ts
const OPTIONS = {
number: ["maxDocumentsPerQuery", "maxBytesPerQuery"],
string: [
"apiBaseUrl",
"apiClientId",
Expand Down Expand Up @@ -98,6 +99,7 @@ const OPTIONS = {

interface Options {
string: string[];
number: string[];
boolean: string[];
array: string[];
alias: Record<string, string>;
Expand All @@ -106,6 +108,7 @@ interface Options {

export const ALL_CONFIG_KEYS = new Set(
(OPTIONS.string as readonly string[])
.concat(OPTIONS.number)
.concat(OPTIONS.array)
.concat(OPTIONS.boolean)
.concat(Object.keys(OPTIONS.alias))
Expand Down Expand Up @@ -175,6 +178,8 @@ export interface UserConfig extends CliOptions {
loggers: Array<"stderr" | "disk" | "mcp">;
idleTimeoutMs: number;
notificationTimeoutMs: number;
maxDocumentsPerQuery: number;
maxBytesPerQuery: number;
atlasTemporaryDatabaseUserLifetimeMs: number;
}

Expand Down Expand Up @@ -202,6 +207,8 @@ export const defaultUserConfig: UserConfig = {
idleTimeoutMs: 10 * 60 * 1000, // 10 minutes
notificationTimeoutMs: 9 * 60 * 1000, // 9 minutes
httpHeaders: {},
maxDocumentsPerQuery: 100, // By default, we only fetch a maximum 100 documents per query / aggregation
maxBytesPerQuery: 16 * 1024 * 1024, // By default, we only return ~16 mb of data per query / aggregation
atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours
};

Expand Down
1 change: 1 addition & 0 deletions src/common/logger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export const LogId = {
mongodbConnectFailure: mongoLogId(1_004_001),
mongodbDisconnectFailure: mongoLogId(1_004_002),
mongodbConnectTry: mongoLogId(1_004_003),
mongodbCursorCloseError: mongoLogId(1_004_004),

toolUpdateFailure: mongoLogId(1_005_001),
resourceUpdateFailure: mongoLogId(1_005_002),
Expand Down
103 changes: 103 additions & 0 deletions src/helpers/collectCursorUntilMaxBytes.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import { calculateObjectSize } from "bson";
import type { AggregationCursor, FindCursor } from "mongodb";

/**
 * Resolves the effective response-size limit from the server-wide
 * `maxBytesPerQuery` config and the per-call `responseBytesLimit` tool
 * argument, and reports which of the two ends up capping the response.
 *
 * A nullish, non-numeric, zero or negative value on either side means "no
 * limit from that side"; when both sides apply, the smaller one wins.
 */
export function getResponseBytesLimit(
    toolResponseBytesLimit: number | undefined | null,
    configuredMaxBytesPerQuery: unknown
): {
    cappedBy: "config.maxBytesPerQuery" | "tool.responseBytesLimit" | undefined;
    limit: number;
} {
    // The config value arrives untyped; coerce it through parseInt so that
    // numbers and numeric strings are handled uniformly.
    const parsedConfigLimit: number = parseInt(String(configuredMaxBytesPerQuery), 10);

    // Negative, zero, or unparseable config disables the server-side cap.
    const configApplies = !Number.isNaN(parsedConfigLimit) && parsedConfigLimit > 0;

    // A missing, non-numeric, zero or negative tool argument disables the
    // tool-side cap.
    const toolApplies = typeof toolResponseBytesLimit === "number" && toolResponseBytesLimit > 0;

    if (!configApplies) {
        return toolApplies
            ? { cappedBy: "tool.responseBytesLimit", limit: toolResponseBytesLimit }
            : { cappedBy: undefined, limit: 0 };
    }

    if (!toolApplies) {
        return { cappedBy: "config.maxBytesPerQuery", limit: parsedConfigLimit };
    }

    // Both caps apply: the stricter (smaller) one wins; ties are attributed
    // to the tool argument, matching the original Math.min-based selection.
    return parsedConfigLimit < toolResponseBytesLimit
        ? { cappedBy: "config.maxBytesPerQuery", limit: parsedConfigLimit }
        : { cappedBy: "tool.responseBytesLimit", limit: toolResponseBytesLimit };
}

/**
 * Guard rail against accidental memory overflow on the MCP server.
 *
 * Documents are pulled from the cursor one at a time and buffered only while
 * the running BSON size stays strictly below the effective byte limit — the
 * smaller of the tool's responseBytesLimit argument and the server's
 * configured maxBytesPerQuery (see getResponseBytesLimit). Iteration also
 * stops when the cursor is exhausted or the provided abort signal fires.
 */
export async function collectCursorUntilMaxBytesLimit<T = unknown>({
    cursor,
    toolResponseBytesLimit,
    configuredMaxBytesPerQuery,
    abortSignal,
}: {
    cursor: FindCursor<T> | AggregationCursor<T>;
    toolResponseBytesLimit: number | undefined | null;
    configuredMaxBytesPerQuery: unknown;
    abortSignal?: AbortSignal;
}): Promise<{ cappedBy: "config.maxBytesPerQuery" | "tool.responseBytesLimit" | undefined; documents: T[] }> {
    const { limit: effectiveByteLimit, cappedBy } = getResponseBytesLimit(
        toolResponseBytesLimit,
        configuredMaxBytesPerQuery
    );

    // Both limits disabled (nullish / zero / negative on both sides): drain
    // the cursor without any cap.
    if (effectiveByteLimit <= 0) {
        return { cappedBy, documents: await cursor.toArray() };
    }

    const collected: T[] = [];
    let bytesSoFar = 0;
    let limitWasHit = false;

    // Keep pulling until the request aborts, the cursor runs dry, or the next
    // document would push us past the byte limit.
    while (!abortSignal?.aborted) {
        // tryNext resolves null once the cursor is exhausted.
        const candidate = await cursor.tryNext();
        if (!candidate) {
            break;
        }

        const candidateBytes = calculateObjectSize(candidate);
        if (bytesSoFar + candidateBytes >= effectiveByteLimit) {
            limitWasHit = true;
            break;
        }

        bytesSoFar += candidateBytes;
        collected.push(candidate);
    }

    return {
        // Only report a cap when the byte limit actually truncated the
        // results; an exhausted or aborted cursor is not "capped".
        cappedBy: limitWasHit ? cappedBy : undefined,
        documents: collected,
    };
}
26 changes: 26 additions & 0 deletions src/helpers/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/**
 * Upper bound on the maxTimeMS used for FindCursor.countDocuments.
 *
 * Kept deliberately small: the count is expected to resolve by the time the
 * batch of documents has been retrieved, so that it never holds the final
 * response back.
 */
export const QUERY_COUNT_MAX_TIME_MS_CAP: number = 10_000;

/**
 * Upper bound on the maxTimeMS used for counting the documents produced by an
 * aggregation.
 */
export const AGG_COUNT_MAX_TIME_MS_CAP: number = 60_000;

/** One mebibyte, in bytes. */
export const ONE_MB: number = 1024 * 1024;

/**
 * Maps each limit that may cap a cursor's results to the human-readable label
 * included in the response text sent back to the LLM.
 */
export const CURSOR_LIMITS_TO_LLM_TEXT = {
    "config.maxDocumentsPerQuery": "server's configured - maxDocumentsPerQuery",
    "config.maxBytesPerQuery": "server's configured - maxBytesPerQuery",
    "tool.responseBytesLimit": "tool's parameter - responseBytesLimit",
} as const;
12 changes: 12 additions & 0 deletions src/helpers/operationWithFallback.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
type OperationCallback<OperationResult> = () => Promise<OperationResult>;

export async function operationWithFallback<OperationResult, FallbackValue>(
performOperation: OperationCallback<OperationResult>,
fallback: FallbackValue
): Promise<OperationResult | FallbackValue> {
try {
return await performOperation();
} catch {
return fallback;
}
}
159 changes: 135 additions & 24 deletions src/tools/mongodb/read/aggregate.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
import { z } from "zod";
import type { AggregationCursor } from "mongodb";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import type { ToolArgs, OperationType } from "../../tool.js";
import type { ToolArgs, OperationType, ToolExecutionContext } from "../../tool.js";
import { formatUntrustedData } from "../../tool.js";
import { checkIndexUsage } from "../../../helpers/indexCheck.js";
import { EJSON } from "bson";
import { type Document, EJSON } from "bson";
import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
import { collectCursorUntilMaxBytesLimit } from "../../../helpers/collectCursorUntilMaxBytes.js";
import { operationWithFallback } from "../../../helpers/operationWithFallback.js";
import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js";
import { zEJSON } from "../../args.js";
import { LogId } from "../../../common/logger.js";

// Zod shape for the aggregate tool's arguments. responseBytesLimit defaults
// to ONE_MB and is additionally capped server-side by config.maxBytesPerQuery
// (see collectCursorUntilMaxBytesLimit), so raising it cannot exceed the
// server's configured ceiling.
export const AggregateArgs = {
    pipeline: z.array(zEJSON()).describe("An array of aggregation stages to execute"),
    responseBytesLimit: z.number().optional().default(ONE_MB).describe(`\
The maximum number of bytes to return in the response. This value is capped by the server’s configured maxBytesPerQuery and cannot be exceeded. \
Note to LLM: If the entire aggregation result is required, use the "export" tool instead of increasing this limit.\
`),
};

export class AggregateTool extends MongoDBToolBase {
Expand All @@ -21,32 +31,80 @@ export class AggregateTool extends MongoDBToolBase {
};
public operationType: OperationType = "read";

protected async execute({
database,
collection,
pipeline,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
const provider = await this.ensureConnected();
protected async execute(
{ database, collection, pipeline, responseBytesLimit }: ToolArgs<typeof this.argsShape>,
{ signal }: ToolExecutionContext
): Promise<CallToolResult> {
let aggregationCursor: AggregationCursor | undefined;
try {
const provider = await this.ensureConnected();

this.assertOnlyUsesPermittedStages(pipeline);
this.assertOnlyUsesPermittedStages(pipeline);

// Check if aggregate operation uses an index if enabled
if (this.config.indexCheck) {
await checkIndexUsage(provider, database, collection, "aggregate", async () => {
return provider
.aggregate(database, collection, pipeline, {}, { writeConcern: undefined })
.explain("queryPlanner");
});
}
// Check if aggregate operation uses an index if enabled
if (this.config.indexCheck) {
await checkIndexUsage(provider, database, collection, "aggregate", async () => {
return provider
.aggregate(database, collection, pipeline, {}, { writeConcern: undefined })
.explain("queryPlanner");
});
}

const documents = await provider.aggregate(database, collection, pipeline).toArray();
const cappedResultsPipeline = [...pipeline];
if (this.config.maxDocumentsPerQuery > 0) {
cappedResultsPipeline.push({ $limit: this.config.maxDocumentsPerQuery });
}
aggregationCursor = provider.aggregate(database, collection, cappedResultsPipeline);

return {
content: formatUntrustedData(
`The aggregation resulted in ${documents.length} documents.`,
documents.length > 0 ? EJSON.stringify(documents) : undefined
),
};
const [totalDocuments, cursorResults] = await Promise.all([
this.countAggregationResultDocuments({ provider, database, collection, pipeline }),
collectCursorUntilMaxBytesLimit({
cursor: aggregationCursor,
configuredMaxBytesPerQuery: this.config.maxBytesPerQuery,
toolResponseBytesLimit: responseBytesLimit,
abortSignal: signal,
}),
]);

// If the total number of documents that the aggregation would've
// resulted in would be greater than the configured
// maxDocumentsPerQuery then we know for sure that the results were
// capped.
const aggregationResultsCappedByMaxDocumentsLimit =
this.config.maxDocumentsPerQuery > 0 &&
!!totalDocuments &&
totalDocuments > this.config.maxDocumentsPerQuery;

return {
content: formatUntrustedData(
this.generateMessage({
aggResultsCount: totalDocuments,
documents: cursorResults.documents,
appliedLimits: [
aggregationResultsCappedByMaxDocumentsLimit ? "config.maxDocumentsPerQuery" : undefined,
cursorResults.cappedBy,
].filter((limit): limit is keyof typeof CURSOR_LIMITS_TO_LLM_TEXT => !!limit),
}),
cursorResults.documents.length > 0 ? EJSON.stringify(cursorResults.documents) : undefined
),
};
} finally {
if (aggregationCursor) {
void this.safeCloseCursor(aggregationCursor);
}
}
}

/**
 * Closes the given cursor, logging (rather than throwing) any failure so
 * that cleanup never masks the tool call's own result or error.
 */
private async safeCloseCursor(cursor: AggregationCursor<unknown>): Promise<void> {
    await cursor.close().catch((error: unknown) => {
        const reason = error instanceof Error ? error.message : String(error);
        this.session.logger.warning({
            id: LogId.mongodbCursorCloseError,
            context: "aggregate tool",
            message: `Error when closing the cursor - ${reason}`,
        });
    });
}

private assertOnlyUsesPermittedStages(pipeline: Record<string, unknown>[]): void {
Expand All @@ -70,4 +128,57 @@ export class AggregateTool extends MongoDBToolBase {
}
}
}

/**
 * Best-effort count of the documents the given pipeline would produce,
 * computed by appending a `$count` stage and running the aggregation with a
 * bounded maxTimeMS. Resolves to undefined when the count query itself fails
 * (e.g. times out), so callers can distinguish "unknown" from a real count.
 */
private async countAggregationResultDocuments({
    provider,
    database,
    collection,
    pipeline,
}: {
    provider: NodeDriverServiceProvider;
    database: string;
    collection: string;
    pipeline: Document[];
}): Promise<number | undefined> {
    const countPipeline = [...pipeline, { $count: "totalDocuments" }];
    return await operationWithFallback(async (): Promise<number | undefined> => {
        const countResult = await provider
            .aggregate(database, collection, countPipeline)
            .maxTimeMS(AGG_COUNT_MAX_TIME_MS_CAP)
            .toArray();

        // $count yields exactly one document when the pipeline matched
        // anything; an empty result means the pipeline produced no documents,
        // which we report as 0 (not undefined — the count itself succeeded).
        const countDocument: unknown = countResult.length === 1 ? countResult[0] : undefined;
        if (
            countDocument &&
            typeof countDocument === "object" &&
            "totalDocuments" in countDocument &&
            typeof countDocument.totalDocuments === "number"
        ) {
            return countDocument.totalDocuments;
        }
        return 0;
    }, undefined);
}

/**
 * Builds the human/LLM-facing summary line for the aggregation response.
 *
 * @param aggResultsCount - Total documents the pipeline would produce, or
 *   undefined when the count query failed (reported as "indeterminable").
 * @param documents - The documents actually being returned after capping.
 * @param appliedLimits - The limits that truncated the results, if any.
 */
private generateMessage({
    aggResultsCount,
    documents,
    appliedLimits,
}: {
    aggResultsCount: number | undefined;
    documents: unknown[];
    appliedLimits: (keyof typeof CURSOR_LIMITS_TO_LLM_TEXT)[];
}): string {
    // NOTE: the backslash line-continuations keep the template literals free
    // of embedded newlines; leading whitespace inside them is intentional and
    // part of the emitted string.
    const appliedLimitText = appliedLimits.length
        ? `\
while respecting the applied limits of ${appliedLimits.map((limit) => CURSOR_LIMITS_TO_LLM_TEXT[limit]).join(", ")}. \
Note to LLM: If the entire query result is required then use "export" tool to export the query results.\
`
        : "";

    return `\
The aggregation resulted in ${aggResultsCount === undefined ? "indeterminable number of" : aggResultsCount} documents. \
Returning ${documents.length} documents${appliedLimitText ? ` ${appliedLimitText}` : "."}\
`;
}
}
Loading
Loading