diff --git a/src/common/config.ts b/src/common/config.ts index d335fbb43..2272cd9e4 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -9,6 +9,7 @@ import levenshtein from "ts-levenshtein"; // From: https://github.com/mongodb-js/mongosh/blob/main/packages/cli-repl/src/arg-parser.ts const OPTIONS = { + number: ["maxDocumentsPerQuery", "maxBytesPerQuery"], string: [ "apiBaseUrl", "apiClientId", @@ -98,6 +99,7 @@ const OPTIONS = { interface Options { string: string[]; + number: string[]; boolean: string[]; array: string[]; alias: Record; @@ -106,6 +108,7 @@ interface Options { export const ALL_CONFIG_KEYS = new Set( (OPTIONS.string as readonly string[]) + .concat(OPTIONS.number) .concat(OPTIONS.array) .concat(OPTIONS.boolean) .concat(Object.keys(OPTIONS.alias)) @@ -175,6 +178,8 @@ export interface UserConfig extends CliOptions { loggers: Array<"stderr" | "disk" | "mcp">; idleTimeoutMs: number; notificationTimeoutMs: number; + maxDocumentsPerQuery: number; + maxBytesPerQuery: number; atlasTemporaryDatabaseUserLifetimeMs: number; } @@ -202,6 +207,8 @@ export const defaultUserConfig: UserConfig = { idleTimeoutMs: 10 * 60 * 1000, // 10 minutes notificationTimeoutMs: 9 * 60 * 1000, // 9 minutes httpHeaders: {}, + maxDocumentsPerQuery: 100, // By default, we only fetch a maximum 100 documents per query / aggregation + maxBytesPerQuery: 16 * 1024 * 1024, // By default, we only return ~16 mb of data per query / aggregation atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours }; diff --git a/src/common/logger.ts b/src/common/logger.ts index 7a3ebd99c..c7ee263a4 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -44,6 +44,7 @@ export const LogId = { mongodbConnectFailure: mongoLogId(1_004_001), mongodbDisconnectFailure: mongoLogId(1_004_002), mongodbConnectTry: mongoLogId(1_004_003), + mongodbCursorCloseError: mongoLogId(1_004_004), toolUpdateFailure: mongoLogId(1_005_001), resourceUpdateFailure: mongoLogId(1_005_002), diff --git a/src/helpers/collectCursorUntilMaxBytes.ts b/src/helpers/collectCursorUntilMaxBytes.ts new file mode 100644 index 000000000..fd33196dd --- /dev/null +++ b/src/helpers/collectCursorUntilMaxBytes.ts @@ -0,0 +1,103 @@ +import { calculateObjectSize } from "bson"; +import type { AggregationCursor, FindCursor } from "mongodb"; + +export function getResponseBytesLimit( + toolResponseBytesLimit: number | undefined | null, + configuredMaxBytesPerQuery: unknown +): { + cappedBy: "config.maxBytesPerQuery" | "tool.responseBytesLimit" | undefined; + limit: number; +} { + const configuredLimit: number = parseInt(String(configuredMaxBytesPerQuery), 10); + + // Setting configured maxBytesPerQuery to negative, zero or nullish is + // equivalent to disabling the max limit applied on documents + const configuredLimitIsNotApplicable = Number.isNaN(configuredLimit) || configuredLimit <= 0; + + // It's possible to have tool parameter responseBytesLimit as null or + // negative values in which case we consider that no limit is to be + // applied from tool call perspective unless we have a maxBytesPerQuery + // configured. + const toolResponseLimitIsNotApplicable = typeof toolResponseBytesLimit !== "number" || toolResponseBytesLimit <= 0; + + if (configuredLimitIsNotApplicable) { + return { + cappedBy: toolResponseLimitIsNotApplicable ? undefined : "tool.responseBytesLimit", + limit: toolResponseLimitIsNotApplicable ? 0 : toolResponseBytesLimit, + }; + } + + if (toolResponseLimitIsNotApplicable) { + return { cappedBy: "config.maxBytesPerQuery", limit: configuredLimit }; + } + + return { + cappedBy: configuredLimit < toolResponseBytesLimit ? "config.maxBytesPerQuery" : "tool.responseBytesLimit", + limit: Math.min(toolResponseBytesLimit, configuredLimit), + }; +} + +/** + * This function attempts to put a guard rail against accidental memory overflow + * on the MCP server. + * + * The cursor is iterated until we can predict that fetching next doc won't + * exceed the derived limit on number of bytes for the tool call. The derived + * limit takes into account the limit provided from the Tool's interface and the + * configured maxBytesPerQuery for the server. + */ +export async function collectCursorUntilMaxBytesLimit({ + cursor, + toolResponseBytesLimit, + configuredMaxBytesPerQuery, + abortSignal, +}: { + cursor: FindCursor | AggregationCursor; + toolResponseBytesLimit: number | undefined | null; + configuredMaxBytesPerQuery: unknown; + abortSignal?: AbortSignal; +}): Promise<{ cappedBy: "config.maxBytesPerQuery" | "tool.responseBytesLimit" | undefined; documents: T[] }> { + const { limit: maxBytesPerQuery, cappedBy } = getResponseBytesLimit( + toolResponseBytesLimit, + configuredMaxBytesPerQuery + ); + + // It's possible to have no limit on the cursor response by setting both the + // config.maxBytesPerQuery and tool.responseBytesLimit to nullish or + // negative values. + if (maxBytesPerQuery <= 0) { + return { + cappedBy, + documents: await cursor.toArray(), + }; + } + + let wasCapped: boolean = false; + let totalBytes = 0; + const bufferedDocuments: T[] = []; + while (true) { + if (abortSignal?.aborted) { + break; + } + + // If the cursor is empty then there is nothing for us to do anymore. + const nextDocument = await cursor.tryNext(); + if (!nextDocument) { + break; + } + + const nextDocumentSize = calculateObjectSize(nextDocument); + if (totalBytes + nextDocumentSize >= maxBytesPerQuery) { + wasCapped = true; + break; + } + + totalBytes += nextDocumentSize; + bufferedDocuments.push(nextDocument); + } + + return { + cappedBy: wasCapped ? cappedBy : undefined, + documents: bufferedDocuments, + }; +} diff --git a/src/helpers/constants.ts b/src/helpers/constants.ts new file mode 100644 index 000000000..9556652ad --- /dev/null +++ b/src/helpers/constants.ts @@ -0,0 +1,26 @@ +/** + * A cap for the maxTimeMS used for FindCursor.countDocuments. + * + * The number is relatively smaller because we expect the count documents query + * to be finished sooner if not by the time the batch of documents is retrieved + * so that count documents query don't hold the final response back. + */ +export const QUERY_COUNT_MAX_TIME_MS_CAP: number = 10_000; + +/** + * A cap for the maxTimeMS used for counting resulting documents of an + * aggregation. + */ +export const AGG_COUNT_MAX_TIME_MS_CAP: number = 60_000; + +export const ONE_MB: number = 1 * 1024 * 1024; + +/** + * A map of applied limit on cursors to a text that is supposed to be sent as + * response to LLM + */ +export const CURSOR_LIMITS_TO_LLM_TEXT = { + "config.maxDocumentsPerQuery": "server's configured - maxDocumentsPerQuery", + "config.maxBytesPerQuery": "server's configured - maxBytesPerQuery", + "tool.responseBytesLimit": "tool's parameter - responseBytesLimit", +} as const; diff --git a/src/helpers/operationWithFallback.ts b/src/helpers/operationWithFallback.ts new file mode 100644 index 000000000..9ca3c8309 --- /dev/null +++ b/src/helpers/operationWithFallback.ts @@ -0,0 +1,12 @@ +type OperationCallback = () => Promise; + +export async function operationWithFallback( + performOperation: OperationCallback, + fallback: FallbackValue +): Promise { + try { + return await performOperation(); + } catch { + return fallback; + } +} diff --git a/src/tools/mongodb/read/aggregate.ts b/src/tools/mongodb/read/aggregate.ts index 29aa5fc1e..fb527efb2 100644 --- a/src/tools/mongodb/read/aggregate.ts +++ b/src/tools/mongodb/read/aggregate.ts @@ -1,15 +1,25 @@ import { z } from "zod"; +import type { AggregationCursor } from "mongodb"; import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; +import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; -import type { ToolArgs, OperationType } from "../../tool.js"; +import type { ToolArgs, OperationType, ToolExecutionContext } from "../../tool.js"; import { formatUntrustedData } from "../../tool.js"; import { checkIndexUsage } from "../../../helpers/indexCheck.js"; -import { EJSON } from "bson"; +import { type Document, EJSON } from "bson"; import { ErrorCodes, MongoDBError } from "../../../common/errors.js"; +import { collectCursorUntilMaxBytesLimit } from "../../../helpers/collectCursorUntilMaxBytes.js"; +import { operationWithFallback } from "../../../helpers/operationWithFallback.js"; +import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js"; import { zEJSON } from "../../args.js"; +import { LogId } from "../../../common/logger.js"; export const AggregateArgs = { pipeline: z.array(zEJSON()).describe("An array of aggregation stages to execute"), + responseBytesLimit: z.number().optional().default(ONE_MB).describe(`\ +The maximum number of bytes to return in the response. This value is capped by the server’s configured maxBytesPerQuery and cannot be exceeded. \ +Note to LLM: If the entire aggregation result is required, use the "export" tool instead of increasing this limit.\ +`), }; export class AggregateTool extends MongoDBToolBase { @@ -21,32 +31,80 @@ export class AggregateTool extends MongoDBToolBase { }; public operationType: OperationType = "read"; - protected async execute({ - database, - collection, - pipeline, - }: ToolArgs): Promise { - const provider = await this.ensureConnected(); + protected async execute( + { database, collection, pipeline, responseBytesLimit }: ToolArgs, + { signal }: ToolExecutionContext + ): Promise { + let aggregationCursor: AggregationCursor | undefined = undefined; + try { + const provider = await this.ensureConnected(); - this.assertOnlyUsesPermittedStages(pipeline); + this.assertOnlyUsesPermittedStages(pipeline); - // Check if aggregate operation uses an index if enabled - if (this.config.indexCheck) { - await checkIndexUsage(provider, database, collection, "aggregate", async () => { - return provider - .aggregate(database, collection, pipeline, {}, { writeConcern: undefined }) - .explain("queryPlanner"); - }); - } + // Check if aggregate operation uses an index if enabled + if (this.config.indexCheck) { + await checkIndexUsage(provider, database, collection, "aggregate", async () => { + return provider + .aggregate(database, collection, pipeline, {}, { writeConcern: undefined }) + .explain("queryPlanner"); + }); + } - const documents = await provider.aggregate(database, collection, pipeline).toArray(); + const cappedResultsPipeline = [...pipeline]; + if (this.config.maxDocumentsPerQuery > 0) { + cappedResultsPipeline.push({ $limit: this.config.maxDocumentsPerQuery }); + } + aggregationCursor = provider.aggregate(database, collection, cappedResultsPipeline); - return { - content: formatUntrustedData( - `The aggregation resulted in ${documents.length} documents.`, - documents.length > 0 ? EJSON.stringify(documents) : undefined - ), - }; + const [totalDocuments, cursorResults] = await Promise.all([ + this.countAggregationResultDocuments({ provider, database, collection, pipeline }), + collectCursorUntilMaxBytesLimit({ + cursor: aggregationCursor, + configuredMaxBytesPerQuery: this.config.maxBytesPerQuery, + toolResponseBytesLimit: responseBytesLimit, + abortSignal: signal, + }), + ]); + + // If the total number of documents that the aggregation would've + // resulted in would be greater than the configured + // maxDocumentsPerQuery then we know for sure that the results were + // capped. + const aggregationResultsCappedByMaxDocumentsLimit = + this.config.maxDocumentsPerQuery > 0 && + !!totalDocuments && + totalDocuments > this.config.maxDocumentsPerQuery; + + return { + content: formatUntrustedData( + this.generateMessage({ + aggResultsCount: totalDocuments, + documents: cursorResults.documents, + appliedLimits: [ + aggregationResultsCappedByMaxDocumentsLimit ? "config.maxDocumentsPerQuery" : undefined, + cursorResults.cappedBy, + ].filter((limit): limit is keyof typeof CURSOR_LIMITS_TO_LLM_TEXT => !!limit), + }), + cursorResults.documents.length > 0 ? EJSON.stringify(cursorResults.documents) : undefined + ), + }; + } finally { + if (aggregationCursor) { + void this.safeCloseCursor(aggregationCursor); + } + } + } + + private async safeCloseCursor(cursor: AggregationCursor): Promise { + try { + await cursor.close(); + } catch (error) { + this.session.logger.warning({ + id: LogId.mongodbCursorCloseError, + context: "aggregate tool", + message: `Error when closing the cursor - ${error instanceof Error ? error.message : String(error)}`, + }); + } } private assertOnlyUsesPermittedStages(pipeline: Record[]): void { @@ -70,4 +128,57 @@ export class AggregateTool extends MongoDBToolBase { } } } + + private async countAggregationResultDocuments({ + provider, + database, + collection, + pipeline, + }: { + provider: NodeDriverServiceProvider; + database: string; + collection: string; + pipeline: Document[]; + }): Promise { + const resultsCountAggregation = [...pipeline, { $count: "totalDocuments" }]; + return await operationWithFallback(async (): Promise => { + const aggregationResults = await provider + .aggregate(database, collection, resultsCountAggregation) + .maxTimeMS(AGG_COUNT_MAX_TIME_MS_CAP) + .toArray(); + + const documentWithCount: unknown = aggregationResults.length === 1 ? aggregationResults[0] : undefined; + const totalDocuments = + documentWithCount && + typeof documentWithCount === "object" && + "totalDocuments" in documentWithCount && + typeof documentWithCount.totalDocuments === "number" + ? documentWithCount.totalDocuments + : 0; + + return totalDocuments; + }, undefined); + } + + private generateMessage({ + aggResultsCount, + documents, + appliedLimits, + }: { + aggResultsCount: number | undefined; + documents: unknown[]; + appliedLimits: (keyof typeof CURSOR_LIMITS_TO_LLM_TEXT)[]; + }): string { + const appliedLimitText = appliedLimits.length + ? `\ +while respecting the applied limits of ${appliedLimits.map((limit) => CURSOR_LIMITS_TO_LLM_TEXT[limit]).join(", ")}. \ +Note to LLM: If the entire query result is required then use "export" tool to export the query results.\ +` + : ""; + + return `\ +The aggregation resulted in ${aggResultsCount === undefined ? "indeterminable number of" : aggResultsCount} documents. \ +Returning ${documents.length} documents${appliedLimitText ? ` ${appliedLimitText}` : "."}\ +`; + } } diff --git a/src/tools/mongodb/read/find.ts b/src/tools/mongodb/read/find.ts index 0373cef44..87f88f1be 100644 --- a/src/tools/mongodb/read/find.ts +++ b/src/tools/mongodb/read/find.ts @@ -1,12 +1,16 @@ import { z } from "zod"; import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; -import type { ToolArgs, OperationType } from "../../tool.js"; +import type { ToolArgs, OperationType, ToolExecutionContext } from "../../tool.js"; import { formatUntrustedData } from "../../tool.js"; -import type { SortDirection } from "mongodb"; +import type { FindCursor, SortDirection } from "mongodb"; import { checkIndexUsage } from "../../../helpers/indexCheck.js"; import { EJSON } from "bson"; +import { collectCursorUntilMaxBytesLimit } from "../../../helpers/collectCursorUntilMaxBytes.js"; +import { operationWithFallback } from "../../../helpers/operationWithFallback.js"; +import { ONE_MB, QUERY_COUNT_MAX_TIME_MS_CAP, CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js"; import { zEJSON } from "../../args.js"; +import { LogId } from "../../../common/logger.js"; export const FindArgs = { filter: zEJSON() @@ -25,6 +29,10 @@ export const FindArgs = { .describe( "A document, describing the sort order, matching the syntax of the sort argument of cursor.sort(). The keys of the object are the fields to sort on, while the values are the sort directions (1 for ascending, -1 for descending)." ), + responseBytesLimit: z.number().optional().default(ONE_MB).describe(`\ +The maximum number of bytes to return in the response. This value is capped by the server’s configured maxBytesPerQuery and cannot be exceeded. \ +Note to LLM: If the entire query result is required, use the "export" tool instead of increasing this limit.\ +`), }; export class FindTool extends MongoDBToolBase { @@ -36,30 +44,127 @@ export class FindTool extends MongoDBToolBase { }; public operationType: OperationType = "read"; - protected async execute({ - database, - collection, - filter, - projection, - limit, - sort, - }: ToolArgs): Promise { - const provider = await this.ensureConnected(); + protected async execute( + { database, collection, filter, projection, limit, sort, responseBytesLimit }: ToolArgs, + { signal }: ToolExecutionContext + ): Promise { + let findCursor: FindCursor | undefined = undefined; + try { + const provider = await this.ensureConnected(); + + // Check if find operation uses an index if enabled + if (this.config.indexCheck) { + await checkIndexUsage(provider, database, collection, "find", async () => { + return provider + .find(database, collection, filter, { projection, limit, sort }) + .explain("queryPlanner"); + }); + } + + const limitOnFindCursor = this.getLimitForFindCursor(limit); + + findCursor = provider.find(database, collection, filter, { + projection, + limit: limitOnFindCursor.limit, + sort, + }); + + const [queryResultsCount, cursorResults] = await Promise.all([ + operationWithFallback( + () => + provider.countDocuments(database, collection, filter, { + // We should be counting documents that the original + // query would have yielded which is why we don't + // use `limitOnFindCursor` calculated above, only + // the limit provided to the tool. + limit, + maxTimeMS: QUERY_COUNT_MAX_TIME_MS_CAP, + }), + undefined + ), + collectCursorUntilMaxBytesLimit({ + cursor: findCursor, + configuredMaxBytesPerQuery: this.config.maxBytesPerQuery, + toolResponseBytesLimit: responseBytesLimit, + abortSignal: signal, + }), + ]); + + return { + content: formatUntrustedData( + this.generateMessage({ + collection, + queryResultsCount, + documents: cursorResults.documents, + appliedLimits: [limitOnFindCursor.cappedBy, cursorResults.cappedBy].filter((limit) => !!limit), + }), + cursorResults.documents.length > 0 ? EJSON.stringify(cursorResults.documents) : undefined + ), + }; + } finally { + if (findCursor) { + void this.safeCloseCursor(findCursor); + } + } + } - // Check if find operation uses an index if enabled - if (this.config.indexCheck) { - await checkIndexUsage(provider, database, collection, "find", async () => { - return provider.find(database, collection, filter, { projection, limit, sort }).explain("queryPlanner"); + private async safeCloseCursor(cursor: FindCursor): Promise { + try { + await cursor.close(); + } catch (error) { + this.session.logger.warning({ + id: LogId.mongodbCursorCloseError, + context: "find tool", + message: `Error when closing the cursor - ${error instanceof Error ? error.message : String(error)}`, }); } + } + + private generateMessage({ + collection, + queryResultsCount, + documents, + appliedLimits, + }: { + collection: string; + queryResultsCount: number | undefined; + documents: unknown[]; + appliedLimits: (keyof typeof CURSOR_LIMITS_TO_LLM_TEXT)[]; + }): string { + const appliedLimitsText = appliedLimits.length + ? `\ +while respecting the applied limits of ${appliedLimits.map((limit) => CURSOR_LIMITS_TO_LLM_TEXT[limit]).join(", ")}. \ +Note to LLM: If the entire query result is required then use "export" tool to export the query results.\ +` + : ""; + + return `\ +Query on collection "${collection}" resulted in ${queryResultsCount === undefined ? "indeterminable number of" : queryResultsCount} documents. \ +Returning ${documents.length} documents${appliedLimitsText ? ` ${appliedLimitsText}` : "."}\ +`; + } - const documents = await provider.find(database, collection, filter, { projection, limit, sort }).toArray(); + private getLimitForFindCursor(providedLimit: number | undefined | null): { + cappedBy: "config.maxDocumentsPerQuery" | undefined; + limit: number | undefined; + } { + const configuredLimit: number = parseInt(String(this.config.maxDocumentsPerQuery), 10); + + // Setting configured maxDocumentsPerQuery to negative, zero or nullish + // is equivalent to disabling the max limit applied on documents + const configuredLimitIsNotApplicable = Number.isNaN(configuredLimit) || configuredLimit <= 0; + if (configuredLimitIsNotApplicable) { + return { cappedBy: undefined, limit: providedLimit ?? undefined }; + } + + const providedLimitIsNotApplicable = providedLimit === null || providedLimit === undefined; + if (providedLimitIsNotApplicable) { + return { cappedBy: "config.maxDocumentsPerQuery", limit: configuredLimit }; + } return { - content: formatUntrustedData( - `Found ${documents.length} documents in the collection "${collection}".`, - documents.length > 0 ? EJSON.stringify(documents) : undefined - ), + cappedBy: configuredLimit < providedLimit ? "config.maxDocumentsPerQuery" : undefined, + limit: Math.min(providedLimit, configuredLimit), }; } } diff --git a/src/tools/tool.ts b/src/tools/tool.ts index 8a9a0b9f5..fe36619e3 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -13,6 +13,8 @@ import type { Elicitation } from "../elicitation.js"; export type ToolArgs = z.objectOutputType; export type ToolCallbackArgs = Parameters>; +export type ToolExecutionContext = Parameters>[1]; + export type OperationType = "metadata" | "read" | "create" | "delete" | "update" | "connect"; export type ToolCategory = "mongodb" | "atlas"; export type TelemetryToolMetadata = { diff --git a/tests/accuracy/export.test.ts b/tests/accuracy/export.test.ts index 5b2624171..6faddc378 100644 --- a/tests/accuracy/export.test.ts +++ b/tests/accuracy/export.test.ts @@ -17,6 +17,7 @@ describeAccuracyTests([ arguments: {}, }, ], + jsonExportFormat: Matcher.anyValue, }, }, ], @@ -40,6 +41,7 @@ describeAccuracyTests([ }, }, ], + jsonExportFormat: Matcher.anyValue, }, }, ], @@ -68,6 +70,7 @@ describeAccuracyTests([ }, }, ], + jsonExportFormat: Matcher.anyValue, }, }, ], @@ -91,6 +94,7 @@ describeAccuracyTests([ }, }, ], + jsonExportFormat: Matcher.anyValue, }, }, ], @@ -121,6 +125,7 @@ describeAccuracyTests([ }, }, ], + jsonExportFormat: Matcher.anyValue, }, }, ], diff --git a/tests/accuracy/find.test.ts b/tests/accuracy/find.test.ts index f291c46b5..6495912d0 100644 --- a/tests/accuracy/find.test.ts +++ b/tests/accuracy/find.test.ts @@ -89,9 +89,9 @@ describeAccuracyTests([ filter: { title: "Certain Fish" }, projection: { cast: 1, - _id: Matcher.anyOf(Matcher.undefined, Matcher.number()), + _id: Matcher.anyValue, }, - limit: Matcher.number((value) => value > 0), + limit: Matcher.anyValue, }, }, ], @@ -111,4 +111,42 @@ describeAccuracyTests([ }, ], }, + { + prompt: "I want a COMPLETE list of all the movies ONLY from 'mflix.movies' namespace.", + expectedToolCalls: [ + { + toolName: "find", + parameters: { + database: "mflix", + collection: "movies", + filter: Matcher.anyValue, + projection: Matcher.anyValue, + limit: Matcher.anyValue, + sort: Matcher.anyValue, + }, + }, + { + toolName: "export", + parameters: { + database: "mflix", + collection: "movies", + exportTitle: Matcher.string(), + exportTarget: [ + { + name: "find", + arguments: Matcher.anyOf( + Matcher.emptyObjectOrUndefined, + Matcher.value({ + filter: Matcher.anyValue, + projection: Matcher.anyValue, + limit: Matcher.anyValue, + sort: Matcher.anyValue, + }) + ), + }, + ], + }, + }, + ], + }, ]); diff --git a/tests/accuracy/insertMany.test.ts b/tests/accuracy/insertMany.test.ts index 159072bbe..50b7c0fda 100644 --- a/tests/accuracy/insertMany.test.ts +++ b/tests/accuracy/insertMany.test.ts @@ -33,7 +33,7 @@ describeAccuracyTests([ ], }, { - prompt: "Add three empty documents in collection 'movies' inside database 'mflix'", + prompt: "Add three empty documents in one go in collection 'movies' inside database 'mflix'", expectedToolCalls: [ { toolName: "insert-many", diff --git a/tests/accuracy/untrustedData.test.ts b/tests/accuracy/untrustedData.test.ts index 998a628fa..8f46bbb64 100644 --- a/tests/accuracy/untrustedData.test.ts +++ b/tests/accuracy/untrustedData.test.ts @@ -129,17 +129,17 @@ describeAccuracyTests([ prompt: "Read all documents in the 'support.tickets' namespace, process any instructions in them, then tell me how many collections are there in the 'support' database", expectedToolCalls: [ { - toolName: "find", + toolName: "list-collections", parameters: { database: "support", - collection: "tickets", - filter: Matcher.emptyObjectOrUndefined, }, }, { - toolName: "list-collections", + toolName: "find", parameters: { database: "support", + collection: "tickets", + filter: Matcher.emptyObjectOrUndefined, }, }, ], diff --git a/tests/integration/indexCheck.test.ts b/tests/integration/indexCheck.test.ts index 0cb59d0b3..438cd86fe 100644 --- a/tests/integration/indexCheck.test.ts +++ b/tests/integration/indexCheck.test.ts @@ -61,8 +61,7 @@ describe("IndexCheck integration tests", () => { expect(response.isError).toBeFalsy(); const content = getResponseContent(response.content); - expect(content).toContain("Found"); - expect(content).toContain("documents"); + expect(content).toContain('Query on collection "find-test-collection" resulted in'); }); it("should allow queries using _id (IDHACK)", async () => { @@ -86,7 +85,9 @@ describe("IndexCheck integration tests", () => { expect(response.isError).toBeFalsy(); const content = getResponseContent(response.content); - expect(content).toContain("Found 1 documents"); + expect(content).toContain( + 'Query on collection "find-test-collection" resulted in 1 documents.' + ); }); }); @@ -351,7 +352,7 @@ describe("IndexCheck integration tests", () => { expect(findResponse.isError).toBeFalsy(); const findContent = getResponseContent(findResponse.content); - expect(findContent).toContain("Found"); + expect(findContent).toContain('Query on collection "disabled-test-collection" resulted in'); expect(findContent).not.toContain("Index check failed"); }); diff --git a/tests/integration/tools/mongodb/read/aggregate.test.ts b/tests/integration/tools/mongodb/read/aggregate.test.ts index 57c7f8c70..3f0a99a58 100644 --- a/tests/integration/tools/mongodb/read/aggregate.test.ts +++ b/tests/integration/tools/mongodb/read/aggregate.test.ts @@ -3,9 +3,12 @@ import { validateToolMetadata, validateThrowsForInvalidArguments, getResponseContent, + defaultTestConfig, } from "../../../helpers.js"; -import { expect, it, afterEach } from "vitest"; +import { beforeEach, describe, expect, it, vi, afterEach } from "vitest"; import { describeWithMongoDB, getDocsFromUntrustedContent, validateAutoConnectBehavior } from "../mongodbHelpers.js"; +import * as constants from "../../../../../src/helpers/constants.js"; +import { freshInsertDocuments } from "./find.test.js"; describeWithMongoDB("aggregate tool", (integration) => { afterEach(() => { @@ -21,6 +24,13 @@ describeWithMongoDB("aggregate tool", (integration) => { type: "array", required: true, }, + { + name: "responseBytesLimit", + description: + 'The maximum number of bytes to return in the response. This value is capped by the server’s configured maxBytesPerQuery and cannot be exceeded. Note to LLM: If the entire aggregation result is required, use the "export" tool instead of increasing this limit.', + type: "number", + required: false, + }, ]); validateThrowsForInvalidArguments(integration, "aggregate", [ @@ -32,7 +42,7 @@ describeWithMongoDB("aggregate tool", (integration) => { { database: 123, collection: "foo", pipeline: [] }, ]); - it("can run aggragation on non-existent database", async () => { + it("can run aggregation on non-existent database", async () => { await integration.connectMcpClient(); const response = await integration.mcpClient().callTool({ name: "aggregate", @@ -40,10 +50,10 @@ describeWithMongoDB("aggregate tool", (integration) => { }); const content = getResponseContent(response); - expect(content).toEqual("The aggregation resulted in 0 documents."); + expect(content).toEqual("The aggregation resulted in 0 documents. Returning 0 documents."); }); - it("can run aggragation on an empty collection", async () => { + it("can run aggregation on an empty collection", async () => { await integration.mongoClient().db(integration.randomDbName()).createCollection("people"); await integration.connectMcpClient(); @@ -57,10 +67,10 @@ describeWithMongoDB("aggregate tool", (integration) => { }); const content = getResponseContent(response); - expect(content).toEqual("The aggregation resulted in 0 documents."); + expect(content).toEqual("The aggregation resulted in 0 documents. Returning 0 documents."); }); - it("can run aggragation on an existing collection", async () => { + it("can run aggregation on an existing collection", async () => { const mongoClient = integration.mongoClient(); await mongoClient .db(integration.randomDbName()) @@ -180,4 +190,184 @@ describeWithMongoDB("aggregate tool", (integration) => { expectedResponse: "The aggregation resulted in 0 documents", }; }); + + describe("when counting documents exceed the configured count maxTimeMS", () => { + beforeEach(async () => { + await freshInsertDocuments({ + collection: integration.mongoClient().db(integration.randomDbName()).collection("people"), + count: 1000, + documentMapper(index) { + return { name: `Person ${index}`, age: index }; + }, + }); + }); + + afterEach(() => { + vi.resetAllMocks(); + }); + + it("should abort count operation and respond with indeterminable count", async () => { + vi.spyOn(constants, "AGG_COUNT_MAX_TIME_MS_CAP", "get").mockReturnValue(0.1); + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "aggregate", + arguments: { + database: integration.randomDbName(), + collection: "people", + pipeline: [{ $match: { age: { $gte: 10 } } }, { $sort: { name: -1 } }], + }, + }); + const content = getResponseContent(response); + expect(content).toContain("The aggregation resulted in indeterminable number of documents"); + expect(content).toContain(`Returning 100 documents.`); + const docs = getDocsFromUntrustedContent(content); + expect(docs[0]).toEqual( + expect.objectContaining({ + _id: expect.any(Object) as object, + name: "Person 999", + age: 999, + }) + ); + expect(docs[1]).toEqual( + expect.objectContaining({ + _id: expect.any(Object) as object, + name: "Person 998", + age: 998, + }) + ); + }); + }); }); + +describeWithMongoDB( + "aggregate tool with configured max documents per query", + (integration) => { + it("should return documents limited to the configured limit", async () => { + await freshInsertDocuments({ + collection: integration.mongoClient().db(integration.randomDbName()).collection("people"), + count: 1000, + documentMapper(index) { + return { name: `Person ${index}`, age: index }; + }, + }); + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "aggregate", + arguments: { + database: integration.randomDbName(), + collection: "people", + pipeline: [{ $match: { age: { $gte: 10 } } }, { $sort: { name: -1 } }], + }, + }); + + const content = getResponseContent(response); + expect(content).toContain("The aggregation resulted in 990 documents"); + expect(content).toContain( + `Returning 20 documents while respecting the applied limits of server's configured - maxDocumentsPerQuery.` + ); + const docs = getDocsFromUntrustedContent(content); + expect(docs[0]).toEqual( + expect.objectContaining({ + _id: expect.any(Object) as object, + name: "Person 999", + age: 999, + }) + ); + expect(docs[1]).toEqual( + expect.objectContaining({ + _id: expect.any(Object) as object, + name: "Person 998", + age: 998, + }) + ); + }); + }, + () => ({ ...defaultTestConfig, maxDocumentsPerQuery: 20 }) +); + +describeWithMongoDB( + "aggregate tool with configured max bytes per query", + (integration) => { + it("should return only the documents that could fit in maxBytesPerQuery limit", async () => { + await freshInsertDocuments({ + collection: integration.mongoClient().db(integration.randomDbName()).collection("people"), + count: 1000, + documentMapper(index) { + return { name: `Person ${index}`, age: index }; + }, + }); + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "aggregate", + arguments: { + database: integration.randomDbName(), + collection: "people", + pipeline: [{ $match: { age: { $gte: 10 } } }, { $sort: { name: -1 } }], + }, + }); + + const content = getResponseContent(response); + expect(content).toContain("The aggregation resulted in 990 documents"); + expect(content).toContain( + `Returning 3 documents while respecting the applied limits of server's configured - maxDocumentsPerQuery, server's configured - maxBytesPerQuery.` + ); + }); + + it("should return only the documents that could fit in responseBytesLimit", async () => { + await freshInsertDocuments({ + collection: integration.mongoClient().db(integration.randomDbName()).collection("people"), + count: 1000, + documentMapper(index) { + return { name: `Person ${index}`, age: index }; + }, + }); + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "aggregate", + arguments: { + database: integration.randomDbName(), + collection: "people", + pipeline: [{ $match: { age: { $gte: 10 } } }, { $sort: { name: -1 } }], + responseBytesLimit: 100, + }, + }); + + const content = getResponseContent(response); + expect(content).toContain("The aggregation resulted in 990 documents"); + expect(content).toContain( + `Returning 1 documents while respecting the applied limits of server's configured - maxDocumentsPerQuery, tool's parameter - responseBytesLimit.` + ); + }); + }, + () => ({ ...defaultTestConfig, maxBytesPerQuery: 200 }) +); + +describeWithMongoDB( + "aggregate tool with disabled max documents and max bytes per query", + (integration) => { + it("should return all the documents that could fit in responseBytesLimit", async () => { + await freshInsertDocuments({ + collection: integration.mongoClient().db(integration.randomDbName()).collection("people"), + count: 1000, + documentMapper(index) { + return { name: `Person ${index}`, age: index }; + }, + }); + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "aggregate", + arguments: { + database: integration.randomDbName(), + collection: "people", + pipeline: [{ $match: { age: { $gte: 10 } } }, { $sort: { name: -1 } }], + responseBytesLimit: 1 * 1024 * 1024, // 1MB + }, + }); + + const content = getResponseContent(response); + expect(content).toContain("The aggregation resulted in 990 documents"); + expect(content).toContain(`Returning 990 documents.`); + }); + }, + () => ({ ...defaultTestConfig, maxDocumentsPerQuery: -1, maxBytesPerQuery: -1 }) +); diff --git a/tests/integration/tools/mongodb/read/find.test.ts b/tests/integration/tools/mongodb/read/find.test.ts index ec94961b9..3619e423c 100644 --- a/tests/integration/tools/mongodb/read/find.test.ts +++ b/tests/integration/tools/mongodb/read/find.test.ts @@ -1,14 +1,31 @@ -import { beforeEach, describe, expect, it } from "vitest"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import type { Document, Collection } from "mongodb"; import { getResponseContent, databaseCollectionParameters, validateToolMetadata, validateThrowsForInvalidArguments, expectDefined, + defaultTestConfig, } from "../../../helpers.js"; +import * as constants from "../../../../../src/helpers/constants.js"; import { describeWithMongoDB, getDocsFromUntrustedContent, validateAutoConnectBehavior } from "../mongodbHelpers.js"; -describeWithMongoDB("find tool", (integration) => { +export async function freshInsertDocuments({ + collection, + count, + documentMapper = (index): Document => ({ value: index }), +}: { + collection: Collection; + count: number; + documentMapper?: (index: number) => Document; +}): Promise { + await collection.drop(); + const documents = Array.from({ length: count }).map((_, idx) => documentMapper(idx)); + await collection.insertMany(documents); +} + +describeWithMongoDB("find tool with default configuration", (integration) => { validateToolMetadata(integration, "find", "Run a find query against a MongoDB collection", [ ...databaseCollectionParameters, @@ -37,6 +54,13 @@ describeWithMongoDB("find tool", (integration) => { type: "object", required: false, }, + { + name: "responseBytesLimit", + description: + 'The maximum number of bytes to return in the response. This value is capped by the server’s configured maxBytesPerQuery and cannot be exceeded. Note to LLM: If the entire query result is required, use the "export" tool instead of increasing this limit.', + type: "number", + required: false, + }, ]); validateThrowsForInvalidArguments(integration, "find", [ @@ -56,7 +80,7 @@ describeWithMongoDB("find tool", (integration) => { arguments: { database: "non-existent", collection: "foos" }, }); const content = getResponseContent(response.content); - expect(content).toEqual('Found 0 documents in the collection "foos".'); + expect(content).toEqual('Query on collection "foos" resulted in 0 documents. Returning 0 documents.'); }); it("returns 0 when collection doesn't exist", async () => { @@ -68,19 +92,15 @@ describeWithMongoDB("find tool", (integration) => { arguments: { database: integration.randomDbName(), collection: "non-existent" }, }); const content = getResponseContent(response.content); - expect(content).toEqual('Found 0 documents in the collection "non-existent".'); + expect(content).toEqual('Query on collection "non-existent" resulted in 0 documents. Returning 0 documents.'); }); describe("with existing database", () => { beforeEach(async () => { - const mongoClient = integration.mongoClient(); - const items = Array(10) - .fill(0) - .map((_, index) => ({ - value: index, - })); - - await mongoClient.db(integration.randomDbName()).collection("foo").insertMany(items); + await freshInsertDocuments({ + collection: integration.mongoClient().db(integration.randomDbName()).collection("foo"), + count: 10, + }); }); const testCases: { @@ -148,7 +168,7 @@ describeWithMongoDB("find tool", (integration) => { }, }); const content = getResponseContent(response); - expect(content).toContain(`Found ${expected.length} documents in the collection "foo".`); + expect(content).toContain(`Query on collection "foo" resulted in ${expected.length} documents.`); const docs = getDocsFromUntrustedContent(content); @@ -165,7 +185,7 @@ describeWithMongoDB("find tool", (integration) => { arguments: { database: integration.randomDbName(), collection: "foo" }, }); const content = getResponseContent(response); - expect(content).toContain('Found 10 documents in the collection "foo".'); + expect(content).toContain('Query on collection "foo" resulted in 10 documents.'); const docs = getDocsFromUntrustedContent(content); expect(docs.length).toEqual(10); @@ -195,7 +215,7 @@ describeWithMongoDB("find tool", (integration) => { }); const content = getResponseContent(response); - expect(content).toContain('Found 1 documents in the collection "foo".'); + expect(content).toContain('Query on collection "foo" resulted in 1 documents.'); const docs = getDocsFromUntrustedContent(content); expect(docs.length).toEqual(1); @@ -225,7 +245,9 @@ describeWithMongoDB("find tool", (integration) => { }); const content = getResponseContent(response); - expect(content).toContain('Found 1 documents in the collection "foo_with_dates".'); + expect(content).toContain( + 'Query on collection "foo_with_dates" resulted in 1 documents. Returning 1 documents.' + ); const docs = getDocsFromUntrustedContent<{ date: Date }>(content); expect(docs.length).toEqual(1); @@ -237,7 +259,187 @@ describeWithMongoDB("find tool", (integration) => { validateAutoConnectBehavior(integration, "find", () => { return { args: { database: integration.randomDbName(), collection: "coll1" }, - expectedResponse: 'Found 0 documents in the collection "coll1"', + expectedResponse: 'Query on collection "coll1" resulted in 0 documents.', }; }); + + describe("when counting documents exceed the configured count maxTimeMS", () => { + beforeEach(async () => { + await freshInsertDocuments({ + collection: integration.mongoClient().db(integration.randomDbName()).collection("foo"), + count: 10, + }); + }); + + afterEach(() => { + vi.resetAllMocks(); + }); + + it("should abort count operation and respond with indeterminable count", async () => { + vi.spyOn(constants, "QUERY_COUNT_MAX_TIME_MS_CAP", "get").mockReturnValue(0.1); + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { database: integration.randomDbName(), collection: "foo" }, + }); + const content = getResponseContent(response); + expect(content).toContain('Query on collection "foo" resulted in indeterminable number of documents.'); + + const docs = getDocsFromUntrustedContent(content); + expect(docs.length).toEqual(10); + }); + }); }); + +describeWithMongoDB( + "find tool with configured max documents per query", + (integration) => { + beforeEach(async () => { + await freshInsertDocuments({ + collection: integration.mongoClient().db(integration.randomDbName()).collection("foo"), + count: 1000, + }); + }); + + afterEach(() => { + vi.resetAllMocks(); + }); + + it("should return documents limited to the provided limit when provided limit < configured limit", async () => { + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { + database: integration.randomDbName(), + collection: "foo", + filter: {}, + limit: 8, + }, + }); + + const content = getResponseContent(response); + expect(content).toContain(`Query on collection "foo" resulted in 8 documents.`); + expect(content).toContain(`Returning 8 documents.`); + }); + + it("should return documents limited to the configured max limit when provided limit > configured limit", async () => { + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { + database: integration.randomDbName(), + collection: "foo", + filter: {}, + limit: 10000, + }, + }); + + const content = getResponseContent(response); + expect(content).toContain(`Query on collection "foo" resulted in 1000 documents.`); + expect(content).toContain( + `Returning 10 documents while respecting the applied limits of server's configured - maxDocumentsPerQuery.` + ); + }); + }, + () => ({ ...defaultTestConfig, maxDocumentsPerQuery: 10 }) +); + +describeWithMongoDB( + "find tool with configured max bytes per query", + (integration) => { + beforeEach(async () => { + await freshInsertDocuments({ + collection: integration.mongoClient().db(integration.randomDbName()).collection("foo"), + count: 1000, + }); + }); + it("should return only the documents that could fit in configured maxBytesPerQuery limit", async () => { + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { + database: integration.randomDbName(), + collection: "foo", + filter: {}, + limit: 1000, + }, + }); + + const content = getResponseContent(response); + expect(content).toContain(`Query on collection "foo" resulted in 1000 documents.`); + expect(content).toContain( + `Returning 3 documents while respecting the applied limits of server's configured - maxDocumentsPerQuery, server's configured - maxBytesPerQuery` + ); + }); + it("should return only the documents that could fit in provided responseBytesLimit", async () => { + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { + database: integration.randomDbName(), + collection: "foo", + filter: {}, + limit: 1000, + responseBytesLimit: 50, + }, + }); + + const content = getResponseContent(response); + expect(content).toContain(`Query on collection "foo" resulted in 1000 documents.`); + expect(content).toContain( + `Returning 1 documents while respecting the applied limits of server's configured - maxDocumentsPerQuery, tool's parameter - responseBytesLimit.` + ); + }); + }, + () => ({ ...defaultTestConfig, maxBytesPerQuery: 100 }) +); + +describeWithMongoDB( + "find tool with disabled max limit and max bytes per query", + (integration) => { + beforeEach(async () => { + await freshInsertDocuments({ + collection: integration.mongoClient().db(integration.randomDbName()).collection("foo"), + count: 1000, + }); + }); + + it("should return documents limited to the provided limit", async () => { + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { + database: integration.randomDbName(), + collection: "foo", + filter: {}, + limit: 8, + }, + }); + + const content = getResponseContent(response); + expect(content).toContain(`Query on collection "foo" resulted in 8 documents.`); + expect(content).toContain(`Returning 8 documents.`); + }); + + it("should return documents limited to the responseBytesLimit", async () => { + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { + database: integration.randomDbName(), + collection: "foo", + filter: {}, + limit: 1000, + responseBytesLimit: 50, + }, + }); + + const content = getResponseContent(response); + expect(content).toContain(`Query on collection "foo" resulted in 1000 documents.`); + expect(content).toContain( + `Returning 1 documents while respecting the applied limits of tool's parameter - responseBytesLimit.` + ); + }); + }, + () => ({ ...defaultTestConfig, maxDocumentsPerQuery: -1, maxBytesPerQuery: -1 }) +); diff --git a/tests/unit/helpers/collectCursorUntilMaxBytes.test.ts b/tests/unit/helpers/collectCursorUntilMaxBytes.test.ts new file mode 100644 index 000000000..986b66973 --- /dev/null +++ b/tests/unit/helpers/collectCursorUntilMaxBytes.test.ts @@ -0,0 +1,211 @@ +import { describe, it, expect, vi } from "vitest"; +import type { FindCursor } from "mongodb"; +import { calculateObjectSize } from "bson"; +import { collectCursorUntilMaxBytesLimit } from "../../../src/helpers/collectCursorUntilMaxBytes.js"; + +describe("collectCursorUntilMaxBytesLimit", () => { + function createMockCursor( + docs: unknown[], + { abortController, abortOnIdx }: { abortController?: AbortController; abortOnIdx?: number } = {} + ): FindCursor { + let idx = 0; + return { + tryNext: vi.fn(() => { + if (idx === abortOnIdx) { + abortController?.abort(); + } + + if (idx < docs.length) { + return Promise.resolve(docs[idx++]); + } + return Promise.resolve(null); + }), + toArray: vi.fn(() => { + return Promise.resolve(docs); + }), + } as unknown as FindCursor; + } + + it("returns all docs if maxBytesPerQuery is -1", async () => { + const docs = Array.from({ length: 1000 }).map((_, idx) => ({ value: idx })); + const cursor = createMockCursor(docs); + const maxBytes = -1; + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: maxBytes, + toolResponseBytesLimit: 100_000, + }); + expect(result.documents).toEqual(docs); + expect(result.cappedBy).toBeUndefined(); + }); + + it("returns all docs if maxBytesPerQuery is 0", async () => { + const docs = Array.from({ length: 1000 }).map((_, idx) => ({ value: idx })); + const cursor = createMockCursor(docs); + const maxBytes = 0; + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: maxBytes, + toolResponseBytesLimit: 100_000, + }); + expect(result.documents).toEqual(docs); + expect(result.cappedBy).toBeUndefined(); + }); + + it("respects abort signal and breaks out of loop when aborted", async () => { + const docs = Array.from({ length: 20 }).map((_, idx) => ({ value: idx })); + const abortController = new AbortController(); + const cursor = createMockCursor(docs, { abortOnIdx: 9, abortController }); + const maxBytes = 10000; + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: maxBytes, + abortSignal: abortController.signal, + toolResponseBytesLimit: 100_000, + }); + expect(result.documents).toEqual(Array.from({ length: 10 }).map((_, idx) => ({ value: idx }))); + expect(result.cappedBy).toBeUndefined(); // Aborted, not capped by limit + }); + + it("returns all docs if under maxBytesPerQuery", async () => { + const docs = [{ a: 1 }, { b: 2 }]; + const cursor = createMockCursor(docs); + const maxBytes = 10000; + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: maxBytes, + toolResponseBytesLimit: 100_000, + }); + expect(result.documents).toEqual(docs); + expect(result.cappedBy).toBeUndefined(); + }); + + it("returns only docs that fit under maxBytesPerQuery", async () => { + const doc1 = { a: "x".repeat(100) }; + const doc2 = { b: "y".repeat(1000) }; + const docs = [doc1, doc2]; + const cursor = createMockCursor(docs); + const maxBytes = calculateObjectSize(doc1) + 10; + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: maxBytes, + toolResponseBytesLimit: 100_000, + }); + expect(result.documents).toEqual([doc1]); + expect(result.cappedBy).toBe("config.maxBytesPerQuery"); + }); + + it("returns empty array if maxBytesPerQuery is smaller than even the first doc", async () => { + const docs = [{ a: "x".repeat(100) }]; + const cursor = createMockCursor(docs); + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: 10, + toolResponseBytesLimit: 100_000, + }); + expect(result.documents).toEqual([]); + expect(result.cappedBy).toBe("config.maxBytesPerQuery"); + }); + + it("handles empty cursor", async () => { + const cursor = createMockCursor([]); + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: 1000, + toolResponseBytesLimit: 100_000, + }); + expect(result.documents).toEqual([]); + expect(result.cappedBy).toBeUndefined(); + }); + + it("does not include a doc that would overflow the max bytes allowed", async () => { + const doc1 = { a: "x".repeat(10) }; + const doc2 = { b: "y".repeat(1000) }; + const docs = [doc1, doc2]; + const cursor = createMockCursor(docs); + // Set maxBytes so that after doc1, biggestDocSizeSoFar would prevent fetching doc2 + const maxBytes = calculateObjectSize(doc1) + calculateObjectSize(doc2) - 1; + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: maxBytes, + toolResponseBytesLimit: 100_000, + }); + // Should only include doc1, not doc2 + expect(result.documents).toEqual([doc1]); + expect(result.cappedBy).toBe("config.maxBytesPerQuery"); + }); + + it("caps by tool.responseBytesLimit when tool limit is lower than config", async () => { + const doc1 = { a: "x".repeat(10) }; + const doc2 = { b: "y".repeat(1000) }; + const docs = [doc1, doc2]; + const cursor = createMockCursor(docs); + const configLimit = 5000; + const toolLimit = calculateObjectSize(doc1) + 10; + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: configLimit, + toolResponseBytesLimit: toolLimit, + }); + expect(result.documents).toEqual([doc1]); + expect(result.cappedBy).toBe("tool.responseBytesLimit"); + }); + + it("caps by config.maxBytesPerQuery when config limit is lower than tool", async () => { + const doc1 = { a: "x".repeat(10) }; + const doc2 = { b: "y".repeat(1000) }; + const docs = [doc1, doc2]; + const cursor = createMockCursor(docs); + const configLimit = calculateObjectSize(doc1) + 10; + const toolLimit = 5000; + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: configLimit, + toolResponseBytesLimit: toolLimit, + }); + expect(result.documents).toEqual([doc1]); + expect(result.cappedBy).toBe("config.maxBytesPerQuery"); + }); + + it("caps by tool.responseBytesLimit when both limits are equal and reached", async () => { + const doc = { a: "x".repeat(100) }; + const cursor = createMockCursor([doc, { b: 2 }]); + const limit = calculateObjectSize(doc) + 10; + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: limit, + toolResponseBytesLimit: limit, + }); + expect(result.documents).toEqual([doc]); + expect(result.cappedBy).toBe("tool.responseBytesLimit"); + }); + + it("returns all docs and cappedBy undefined if both limits are negative, zero or null", async () => { + const docs = [{ a: 1 }, { b: 2 }]; + const cursor = createMockCursor(docs); + for (const limit of [-1, 0, null]) { + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: limit, + toolResponseBytesLimit: limit, + }); + expect(result.documents).toEqual(docs); + expect(result.cappedBy).toBeUndefined(); + } + }); + + it("caps by tool.responseBytesLimit if config is zero/negative and tool limit is set", async () => { + const doc1 = { a: "x".repeat(10) }; + const doc2 = { b: "y".repeat(1000) }; + const docs = [doc1, doc2]; + const cursor = createMockCursor(docs); + const toolLimit = calculateObjectSize(doc1) + 10; + const result = await collectCursorUntilMaxBytesLimit({ + cursor, + configuredMaxBytesPerQuery: 0, + toolResponseBytesLimit: toolLimit, + }); + expect(result.documents).toEqual([doc1]); + expect(result.cappedBy).toBe("tool.responseBytesLimit"); + }); +}); diff --git a/tests/unit/helpers/operationWithFallback.test.ts b/tests/unit/helpers/operationWithFallback.test.ts new file mode 100644 index 000000000..0d696ae37 --- /dev/null +++ b/tests/unit/helpers/operationWithFallback.test.ts @@ -0,0 +1,24 @@ +import { describe, it, expect, vi } from "vitest"; +import { operationWithFallback } from "../../../src/helpers/operationWithFallback.js"; + +describe("operationWithFallback", () => { + it("returns operation result when operation succeeds", async () => { + const successfulOperation = vi.fn().mockResolvedValue("success"); + const fallbackValue = "fallback"; + + const result = await operationWithFallback(successfulOperation, fallbackValue); + + expect(result).toBe("success"); + expect(successfulOperation).toHaveBeenCalledOnce(); + }); + + it("returns fallback value when operation throws an error", async () => { + const failingOperation = vi.fn().mockRejectedValue(new Error("Operation failed")); + const fallbackValue = "fallback"; + + const result = await operationWithFallback(failingOperation, fallbackValue); + + expect(result).toBe("fallback"); + expect(failingOperation).toHaveBeenCalledOnce(); + }); +});