Skip to content

Commit 13d8408

Browse files
chore: abort cursor iteration on request timeouts
1 parent 9d9b9f8 commit 13d8408

File tree

5 files changed

+63
-29
lines changed

5 files changed

+63
-29
lines changed

src/helpers/iterateCursor.ts

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@ import type { AggregationCursor, FindCursor } from "mongodb";
88
* The cursor is iterated until we can predict that fetching next doc won't
99
* exceed the maxBytesPerQuery limit.
1010
*/
11-
export async function iterateCursorUntilMaxBytes(
12-
cursor: FindCursor<unknown> | AggregationCursor<unknown>,
13-
maxBytesPerQuery: number
14-
): Promise<unknown[]> {
11+
export async function iterateCursorUntilMaxBytes({
12+
cursor,
13+
maxBytesPerQuery,
14+
abortSignal,
15+
}: {
16+
cursor: FindCursor<unknown> | AggregationCursor<unknown>;
17+
maxBytesPerQuery: number;
18+
abortSignal?: AbortSignal;
19+
}): Promise<unknown[]> {
1520
// Setting configured limit to zero or negative is equivalent to disabling
1621
// the max bytes limit applied on tool responses.
1722
if (maxBytesPerQuery <= 0) {
@@ -22,6 +27,10 @@ export async function iterateCursorUntilMaxBytes(
2227
let totalBytes = 0;
2328
const bufferedDocuments: unknown[] = [];
2429
while (true) {
30+
if (abortSignal?.aborted) {
31+
break;
32+
}
33+
2534
if (totalBytes + biggestDocSizeSoFar >= maxBytesPerQuery) {
2635
break;
2736
}

src/tools/mongodb/read/aggregate.ts

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import type { AggregationCursor } from "mongodb";
33
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
44
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
55
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
6-
import type { ToolArgs, OperationType } from "../../tool.js";
6+
import type { ToolArgs, OperationType, ToolExecutionContext } from "../../tool.js";
77
import { formatUntrustedData } from "../../tool.js";
88
import { checkIndexUsage } from "../../../helpers/indexCheck.js";
99
import { type Document, EJSON } from "bson";
@@ -25,11 +25,10 @@ export class AggregateTool extends MongoDBToolBase {
2525
};
2626
public operationType: OperationType = "read";
2727

28-
protected async execute({
29-
database,
30-
collection,
31-
pipeline,
32-
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
28+
protected async execute(
29+
{ database, collection, pipeline }: ToolArgs<typeof this.argsShape>,
30+
{ signal }: ToolExecutionContext
31+
): Promise<CallToolResult> {
3332
let aggregationCursor: AggregationCursor | undefined;
3433
try {
3534
const provider = await this.ensureConnected();
@@ -53,7 +52,11 @@ export class AggregateTool extends MongoDBToolBase {
5352

5453
const [totalDocuments, documents] = await Promise.all([
5554
this.countAggregationResultDocuments({ provider, database, collection, pipeline }),
56-
iterateCursorUntilMaxBytes(aggregationCursor, this.config.maxBytesPerQuery),
55+
iterateCursorUntilMaxBytes({
56+
cursor: aggregationCursor,
57+
maxBytesPerQuery: this.config.maxDocumentsPerQuery,
58+
abortSignal: signal,
59+
}),
5760
]);
5861

5962
let messageDescription = `\

src/tools/mongodb/read/find.ts

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { z } from "zod";
22
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
33
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
4-
import type { ToolArgs, OperationType } from "../../tool.js";
4+
import type { ToolArgs, OperationType, ToolExecutionContext } from "../../tool.js";
55
import { formatUntrustedData } from "../../tool.js";
66
import type { FindCursor, SortDirection } from "mongodb";
77
import { checkIndexUsage } from "../../../helpers/indexCheck.js";
@@ -40,14 +40,10 @@ export class FindTool extends MongoDBToolBase {
4040
};
4141
public operationType: OperationType = "read";
4242

43-
protected async execute({
44-
database,
45-
collection,
46-
filter,
47-
projection,
48-
limit,
49-
sort,
50-
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
43+
protected async execute(
44+
{ database, collection, filter, projection, limit, sort }: ToolArgs<typeof this.argsShape>,
45+
{ signal }: ToolExecutionContext
46+
): Promise<CallToolResult> {
5147
let findCursor: FindCursor<unknown> | undefined;
5248
try {
5349
const provider = await this.ensureConnected();
@@ -82,7 +78,11 @@ export class FindTool extends MongoDBToolBase {
8278
}),
8379
undefined
8480
),
85-
iterateCursorUntilMaxBytes(findCursor, this.config.maxBytesPerQuery),
81+
iterateCursorUntilMaxBytes({
82+
cursor: findCursor,
83+
maxBytesPerQuery: this.config.maxBytesPerQuery,
84+
abortSignal: signal,
85+
}),
8686
]);
8787

8888
let messageDescription = `\

src/tools/tool.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import type { Server } from "../server.js";
1111

1212
export type ToolArgs<Args extends ZodRawShape> = z.objectOutputType<Args, ZodNever>;
1313

14+
export type ToolExecutionContext<Args extends ZodRawShape = ZodRawShape> = Parameters<ToolCallback<Args>>[1];
15+
1416
export type OperationType = "metadata" | "read" | "create" | "delete" | "update" | "connect";
1517
export type ToolCategory = "mongodb" | "atlas";
1618
export type TelemetryToolMetadata = {

tests/unit/helpers/iterateCursor.test.ts

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,17 @@ import { calculateObjectSize } from "bson";
44
import { iterateCursorUntilMaxBytes } from "../../../src/helpers/iterateCursor.js";
55

66
describe("iterateCursorUntilMaxBytes", () => {
7-
function createMockCursor(docs: unknown[]): FindCursor<unknown> {
7+
function createMockCursor(
8+
docs: unknown[],
9+
{ abortController, abortOnIdx }: { abortController?: AbortController; abortOnIdx?: number } = {}
10+
): FindCursor<unknown> {
811
let idx = 0;
912
return {
1013
tryNext: vi.fn(() => {
14+
if (idx === abortOnIdx) {
15+
abortController?.abort();
16+
}
17+
1118
if (idx < docs.length) {
1219
return Promise.resolve(docs[idx++]);
1320
}
@@ -23,23 +30,36 @@ describe("iterateCursorUntilMaxBytes", () => {
2330
const docs = Array.from({ length: 1000 }).map((_, idx) => ({ value: idx }));
2431
const cursor = createMockCursor(docs);
2532
const maxBytes = -1;
26-
const result = await iterateCursorUntilMaxBytes(cursor, maxBytes);
33+
const result = await iterateCursorUntilMaxBytes({ cursor, maxBytesPerQuery: maxBytes });
2734
expect(result).toEqual(docs);
2835
});
2936

3037
it("returns all docs if maxBytesPerQuery is 0", async () => {
3138
const docs = Array.from({ length: 1000 }).map((_, idx) => ({ value: idx }));
3239
const cursor = createMockCursor(docs);
3340
const maxBytes = 0;
34-
const result = await iterateCursorUntilMaxBytes(cursor, maxBytes);
41+
const result = await iterateCursorUntilMaxBytes({ cursor, maxBytesPerQuery: maxBytes });
3542
expect(result).toEqual(docs);
3643
});
3744

45+
it("respects abort signal and breaks out of loop when aborted", async () => {
46+
const docs = Array.from({ length: 20 }).map((_, idx) => ({ value: idx }));
47+
const abortController = new AbortController();
48+
const cursor = createMockCursor(docs, { abortOnIdx: 9, abortController });
49+
const maxBytes = 10000;
50+
const result = await iterateCursorUntilMaxBytes({
51+
cursor,
52+
maxBytesPerQuery: maxBytes,
53+
abortSignal: abortController.signal,
54+
});
55+
expect(result).toEqual(Array.from({ length: 10 }).map((_, idx) => ({ value: idx })));
56+
});
57+
3858
it("returns all docs if under maxBytesPerQuery", async () => {
3959
const docs = [{ a: 1 }, { b: 2 }];
4060
const cursor = createMockCursor(docs);
4161
const maxBytes = 10000;
42-
const result = await iterateCursorUntilMaxBytes(cursor, maxBytes);
62+
const result = await iterateCursorUntilMaxBytes({ cursor, maxBytesPerQuery: maxBytes });
4363
expect(result).toEqual(docs);
4464
});
4565

@@ -49,20 +69,20 @@ describe("iterateCursorUntilMaxBytes", () => {
4969
const docs = [doc1, doc2];
5070
const cursor = createMockCursor(docs);
5171
const maxBytes = calculateObjectSize(doc1) + 10;
52-
const result = await iterateCursorUntilMaxBytes(cursor, maxBytes);
72+
const result = await iterateCursorUntilMaxBytes({ cursor, maxBytesPerQuery: maxBytes });
5373
expect(result).toEqual([doc1]);
5474
});
5575

5676
it("returns empty array if maxBytesPerQuery is smaller than even the first doc", async () => {
5777
const docs = [{ a: "x".repeat(100) }];
5878
const cursor = createMockCursor(docs);
59-
const result = await iterateCursorUntilMaxBytes(cursor, 10);
79+
const result = await iterateCursorUntilMaxBytes({ cursor, maxBytesPerQuery: 10 });
6080
expect(result).toEqual([]);
6181
});
6282

6383
it("handles empty cursor", async () => {
6484
const cursor = createMockCursor([]);
65-
const result = await iterateCursorUntilMaxBytes(cursor, 1000);
85+
const result = await iterateCursorUntilMaxBytes({ cursor, maxBytesPerQuery: 1000 });
6686
expect(result).toEqual([]);
6787
});
6888

@@ -73,7 +93,7 @@ describe("iterateCursorUntilMaxBytes", () => {
7393
const cursor = createMockCursor(docs);
7494
// Set maxBytes so that after doc1, biggestDocSizeSoFar would prevent fetching doc2
7595
const maxBytes = calculateObjectSize(doc1) + calculateObjectSize(doc2) - 1;
76-
const result = await iterateCursorUntilMaxBytes(cursor, maxBytes);
96+
const result = await iterateCursorUntilMaxBytes({ cursor, maxBytesPerQuery: maxBytes });
7797
// Should only include doc1, not doc2
7898
expect(result).toEqual([doc1]);
7999
});

0 commit comments

Comments
 (0)