Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4e7d98c
fix: add guards against possible memory overflow
himanshusinghs Sep 9, 2025
250299b
chore: fix existing tests
himanshusinghs Sep 9, 2025
f1be251
chore: tests for the new behavior
himanshusinghs Sep 10, 2025
eff03a8
chore: add missing constants files
himanshusinghs Sep 10, 2025
7d670e8
Apply suggestion from @Copilot
himanshusinghs Sep 10, 2025
8e8c3aa
chore: minor typo
himanshusinghs Sep 10, 2025
6bd5638
fix: removes default limit from find tool schema
himanshusinghs Sep 10, 2025
937908b
chore: add an accuracy test for find tool
himanshusinghs Sep 10, 2025
9d9b9f8
chore: PR feedback
himanshusinghs Sep 10, 2025
13d8408
chore: abort cursor iteration on request timeouts
himanshusinghs Sep 10, 2025
f09b4f4
chore: use correct arg in agg tool
himanshusinghs Sep 10, 2025
7354562
chore: export tool matchers
himanshusinghs Sep 10, 2025
819ed01
chore: accuracy test fixes
himanshusinghs Sep 10, 2025
21f1d3e
Merge remote-tracking branch 'origin/main' into fix/MCP-21-avoid-memo…
himanshusinghs Sep 18, 2025
25e0367
chore: PR feedback about generous config defaults
himanshusinghs Sep 19, 2025
67d3ea8
Merge remote-tracking branch 'origin/main' into fix/MCP-21-avoid-memo…
himanshusinghs Sep 19, 2025
8601c05
chore: fix tests after merge
himanshusinghs Sep 19, 2025
955b7d8
chore: account for cursor close errors
himanshusinghs Sep 19, 2025
bca4bbe
chore: remove unnecessary call
himanshusinghs Sep 19, 2025
811474e
chore: revert export changes
himanshusinghs Sep 19, 2025
e3a87b3
chore: remove eager prediction of overflow
himanshusinghs Sep 19, 2025
e1c95bd
chore: initialise cursor variables
himanshusinghs Sep 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/common/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ export interface UserConfig extends CliOptions {
loggers: Array<"stderr" | "disk" | "mcp">;
idleTimeoutMs: number;
notificationTimeoutMs: number;
maxDocumentsPerQuery: number;
maxBytesPerQuery: number;
atlasTemporaryDatabaseUserLifetimeMs: number;
}

Expand Down Expand Up @@ -202,6 +204,8 @@ export const defaultUserConfig: UserConfig = {
idleTimeoutMs: 10 * 60 * 1000, // 10 minutes
notificationTimeoutMs: 9 * 60 * 1000, // 9 minutes
httpHeaders: {},
maxDocumentsPerQuery: 10, // By default, we only fetch a maximum 10 documents per query / aggregation
maxBytesPerQuery: 1 * 1024 * 1024, // By default, we only return ~1 mb of data per query / aggregation
atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours
};

Expand Down
14 changes: 14 additions & 0 deletions src/helpers/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/**
* A cap for the maxTimeMS used for FindCursor.countDocuments.
*
* The number is relatively smaller because we expect the count documents query
* to be finished sooner if not by the time the batch of documents is retrieved
* so that count documents query don't hold the final response back.
*/
export const QUERY_COUNT_MAX_TIME_MS_CAP: number = 10_000;

/**
* A cap for the maxTimeMS used for counting resulting documents of an
* aggregation.
*/
export const AGG_COUNT_MAX_TIME_MS_CAP: number = 60_000;
54 changes: 54 additions & 0 deletions src/helpers/iterateCursor.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { calculateObjectSize } from "bson";
import type { AggregationCursor, FindCursor } from "mongodb";

/**
* This function attempts to put a guard rail against accidental memory overflow
* on the MCP server.
*
* The cursor is iterated until we can predict that fetching next doc won't
* exceed the maxBytesPerQuery limit.
*/
export async function iterateCursorUntilMaxBytes({
cursor,
maxBytesPerQuery,
abortSignal,
}: {
cursor: FindCursor<unknown> | AggregationCursor<unknown>;
maxBytesPerQuery: number;
abortSignal?: AbortSignal;
}): Promise<unknown[]> {
// Setting configured limit to zero or negative is equivalent to disabling
// the max bytes limit applied on tool responses.
if (maxBytesPerQuery <= 0) {
return await cursor.toArray();
}

let biggestDocSizeSoFar = 0;
let totalBytes = 0;
const bufferedDocuments: unknown[] = [];
while (true) {
if (abortSignal?.aborted) {
break;
}

if (totalBytes + biggestDocSizeSoFar >= maxBytesPerQuery) {
break;
}

const nextDocument = await cursor.tryNext();
if (!nextDocument) {
break;
}

const nextDocumentSize = calculateObjectSize(nextDocument);
if (totalBytes + nextDocumentSize >= maxBytesPerQuery) {
break;
}

totalBytes += nextDocumentSize;
biggestDocSizeSoFar = Math.max(biggestDocSizeSoFar, nextDocumentSize);
bufferedDocuments.push(nextDocument);
}

return bufferedDocuments;
}
12 changes: 12 additions & 0 deletions src/helpers/operationWithFallback.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
type OperationCallback<OperationResult> = () => Promise<OperationResult>;

export async function operationWithFallback<OperationResult, FallbackValue>(
performOperation: OperationCallback<OperationResult>,
fallback: FallbackValue
): Promise<OperationResult | FallbackValue> {
try {
return await performOperation();
} catch {
return fallback;
}
}
115 changes: 89 additions & 26 deletions src/tools/mongodb/read/aggregate.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import { z } from "zod";
import type { AggregationCursor } from "mongodb";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import type { ToolArgs, OperationType } from "../../tool.js";
import type { ToolArgs, OperationType, ToolExecutionContext } from "../../tool.js";
import { formatUntrustedData } from "../../tool.js";
import { checkIndexUsage } from "../../../helpers/indexCheck.js";
import { EJSON } from "bson";
import { type Document, EJSON } from "bson";
import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
import { iterateCursorUntilMaxBytes } from "../../../helpers/iterateCursor.js";
import { operationWithFallback } from "../../../helpers/operationWithFallback.js";
import { AGG_COUNT_MAX_TIME_MS_CAP } from "../../../helpers/constants.js";

export const AggregateArgs = {
pipeline: z.array(z.object({}).passthrough()).describe("An array of aggregation stages to execute"),
Expand All @@ -20,32 +25,59 @@ export class AggregateTool extends MongoDBToolBase {
};
public operationType: OperationType = "read";

protected async execute({
database,
collection,
pipeline,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
const provider = await this.ensureConnected();

this.assertOnlyUsesPermittedStages(pipeline);

// Check if aggregate operation uses an index if enabled
if (this.config.indexCheck) {
await checkIndexUsage(provider, database, collection, "aggregate", async () => {
return provider
.aggregate(database, collection, pipeline, {}, { writeConcern: undefined })
.explain("queryPlanner");
});
}
protected async execute(
{ database, collection, pipeline }: ToolArgs<typeof this.argsShape>,
{ signal }: ToolExecutionContext
): Promise<CallToolResult> {
let aggregationCursor: AggregationCursor | undefined;
try {
const provider = await this.ensureConnected();

this.assertOnlyUsesPermittedStages(pipeline);

const documents = await provider.aggregate(database, collection, pipeline).toArray();
// Check if aggregate operation uses an index if enabled
if (this.config.indexCheck) {
await checkIndexUsage(provider, database, collection, "aggregate", async () => {
return provider
.aggregate(database, collection, pipeline, {}, { writeConcern: undefined })
.explain("queryPlanner");
});
}

const cappedResultsPipeline = [...pipeline];
if (this.config.maxDocumentsPerQuery > 0) {
cappedResultsPipeline.push({ $limit: this.config.maxDocumentsPerQuery });
}
aggregationCursor = provider.aggregate(database, collection, cappedResultsPipeline);

return {
content: formatUntrustedData(
`The aggregation resulted in ${documents.length} documents.`,
documents.length > 0 ? EJSON.stringify(documents) : undefined
),
};
const [totalDocuments, documents] = await Promise.all([
this.countAggregationResultDocuments({ provider, database, collection, pipeline }),
iterateCursorUntilMaxBytes({
cursor: aggregationCursor,
maxBytesPerQuery: this.config.maxBytesPerQuery,
abortSignal: signal,
}),
]);

let messageDescription = `\
The aggregation resulted in ${totalDocuments === undefined ? "indeterminable number of" : totalDocuments} documents.\
`;
if (documents.length) {
messageDescription += ` \
Returning ${documents.length} documents while respecting the applied limits. \
Note to LLM: If entire aggregation result is needed then use "export" tool to export the aggregation results.\
`;
}

return {
content: formatUntrustedData(
messageDescription,
documents.length > 0 ? EJSON.stringify(documents) : undefined
),
};
} finally {
await aggregationCursor?.close();
}
}

private assertOnlyUsesPermittedStages(pipeline: Record<string, unknown>[]): void {
Expand All @@ -69,4 +101,35 @@ export class AggregateTool extends MongoDBToolBase {
}
}
}

private async countAggregationResultDocuments({
provider,
database,
collection,
pipeline,
}: {
provider: NodeDriverServiceProvider;
database: string;
collection: string;
pipeline: Document[];
}): Promise<number | undefined> {
const resultsCountAggregation = [...pipeline, { $count: "totalDocuments" }];
return await operationWithFallback(async (): Promise<number | undefined> => {
const aggregationResults = await provider
.aggregate(database, collection, resultsCountAggregation)
.maxTimeMS(AGG_COUNT_MAX_TIME_MS_CAP)
.toArray();

const documentWithCount: unknown = aggregationResults.length === 1 ? aggregationResults[0] : undefined;
const totalDocuments =
documentWithCount &&
typeof documentWithCount === "object" &&
"totalDocuments" in documentWithCount &&
typeof documentWithCount.totalDocuments === "number"
? documentWithCount.totalDocuments
: 0;

return totalDocuments;
}, undefined);
}
}
7 changes: 1 addition & 6 deletions src/tools/mongodb/read/export.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ export class ExportTool extends MongoDBToolBase {
name: z
.literal("find")
.describe("The literal name 'find' to represent a find cursor as target."),
arguments: z
.object({
...FindArgs,
limit: FindArgs.limit.removeDefault(),
})
.describe("The arguments for 'find' operation."),
arguments: z.object(FindArgs).describe("The arguments for 'find' operation."),
}),
z.object({
name: z
Expand Down
101 changes: 78 additions & 23 deletions src/tools/mongodb/read/find.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import { z } from "zod";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import type { ToolArgs, OperationType } from "../../tool.js";
import type { ToolArgs, OperationType, ToolExecutionContext } from "../../tool.js";
import { formatUntrustedData } from "../../tool.js";
import type { SortDirection } from "mongodb";
import type { FindCursor, SortDirection } from "mongodb";
import { checkIndexUsage } from "../../../helpers/indexCheck.js";
import { EJSON } from "bson";
import { iterateCursorUntilMaxBytes } from "../../../helpers/iterateCursor.js";
import { operationWithFallback } from "../../../helpers/operationWithFallback.js";
import { QUERY_COUNT_MAX_TIME_MS_CAP } from "../../../helpers/constants.js";

export const FindArgs = {
filter: z
Expand All @@ -18,7 +21,7 @@ export const FindArgs = {
.passthrough()
.optional()
.describe("The projection, matching the syntax of the projection argument of db.collection.find()"),
limit: z.number().optional().default(10).describe("The maximum number of documents to return"),
limit: z.number().optional().describe("The maximum number of documents to return"),
sort: z
.object({})
.catchall(z.custom<SortDirection>())
Expand All @@ -37,30 +40,82 @@ export class FindTool extends MongoDBToolBase {
};
public operationType: OperationType = "read";

protected async execute({
database,
collection,
filter,
projection,
limit,
sort,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
const provider = await this.ensureConnected();
protected async execute(
{ database, collection, filter, projection, limit, sort }: ToolArgs<typeof this.argsShape>,
{ signal }: ToolExecutionContext
): Promise<CallToolResult> {
let findCursor: FindCursor<unknown> | undefined;
try {
const provider = await this.ensureConnected();

// Check if find operation uses an index if enabled
if (this.config.indexCheck) {
await checkIndexUsage(provider, database, collection, "find", async () => {
return provider.find(database, collection, filter, { projection, limit, sort }).explain("queryPlanner");
// Check if find operation uses an index if enabled
if (this.config.indexCheck) {
await checkIndexUsage(provider, database, collection, "find", async () => {
return provider
.find(database, collection, filter, { projection, limit, sort })
.explain("queryPlanner");
});
}

const limitOnFindCursor = this.getLimitForFindCursor(limit);

findCursor = provider.find(database, collection, filter, {
projection,
limit: limitOnFindCursor,
sort,
});

const [queryResultsCount, documents] = await Promise.all([
operationWithFallback(
() =>
provider.countDocuments(database, collection, filter, {
// We should be counting documents that the original
// query would have yielded which is why we don't
// use `limitOnFindCursor` calculated above, only
// the limit provided to the tool.
limit,
maxTimeMS: QUERY_COUNT_MAX_TIME_MS_CAP,
}),
undefined
),
iterateCursorUntilMaxBytes({
cursor: findCursor,
maxBytesPerQuery: this.config.maxBytesPerQuery,
abortSignal: signal,
}),
]);

let messageDescription = `\
Query on collection "${collection}" resulted in ${queryResultsCount === undefined ? "indeterminable number of" : queryResultsCount} documents.\
`;
if (documents.length) {
messageDescription += ` \
Returning ${documents.length} documents while respecting the applied limits. \
Note to LLM: If entire query result is needed then use "export" tool to export the query results.\
`;
}

return {
content: formatUntrustedData(
messageDescription,
documents.length > 0 ? EJSON.stringify(documents) : undefined
),
};
} finally {
await findCursor?.close();
}
}

const documents = await provider.find(database, collection, filter, { projection, limit, sort }).toArray();
private getLimitForFindCursor(providedLimit: number | undefined): number | undefined {
const configuredLimit = this.config.maxDocumentsPerQuery;
// Setting configured limit to negative or zero is equivalent to
// disabling the max limit applied on documents
if (configuredLimit <= 0) {
return providedLimit;
}

return {
content: formatUntrustedData(
`Found ${documents.length} documents in the collection "${collection}".`,
documents.length > 0 ? EJSON.stringify(documents) : undefined
),
};
return providedLimit === null || providedLimit === undefined
? configuredLimit
: Math.min(providedLimit, configuredLimit);
}
}
2 changes: 2 additions & 0 deletions src/tools/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import type { Elicitation } from "../elicitation.js";
export type ToolArgs<Args extends ZodRawShape> = z.objectOutputType<Args, ZodNever>;
export type ToolCallbackArgs<Args extends ZodRawShape> = Parameters<ToolCallback<Args>>;

export type ToolExecutionContext<Args extends ZodRawShape = ZodRawShape> = Parameters<ToolCallback<Args>>[1];

export type OperationType = "metadata" | "read" | "create" | "delete" | "update" | "connect";
export type ToolCategory = "mongodb" | "atlas";
export type TelemetryToolMetadata = {
Expand Down
Loading
Loading