Skip to content

Commit 28415d3

Browse files
committed
Refactor progress notification timeout handling in protocol
1 parent 7f6c046 commit 28415d3

File tree

2 files changed

+140
-160
lines changed

2 files changed

+140
-160
lines changed

src/shared/protocol.test.ts

Lines changed: 128 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -63,180 +63,160 @@ describe("protocol tests", () => {
6363
expect(oncloseMock).toHaveBeenCalled();
6464
});
6565

66-
test("should reset timeout when progress notification is received", async () => {
67-
jest.useFakeTimers();
68-
69-
await protocol.connect(transport);
70-
const request = { method: "example", params: {} };
71-
const mockSchema: ZodType<{ result: string }> = z.object({
72-
result: z.string(),
66+
describe("progress notification timeout behavior", () => {
67+
beforeEach(() => {
68+
jest.useFakeTimers();
7369
});
74-
75-
const onProgressMock = jest.fn();
76-
const requestPromise = protocol.request(request, mockSchema, {
77-
timeout: 1000, // Increased timeout for more reliable testing
78-
resetTimeoutOnProgress: true,
79-
onprogress: onProgressMock,
70+
afterEach(() => {
71+
jest.useRealTimers();
8072
});
8173

82-
// Advance time close to timeout
83-
jest.advanceTimersByTime(800);
84-
85-
// Send progress notification
86-
if (transport.onmessage) {
87-
transport.onmessage({
88-
jsonrpc: "2.0",
89-
method: "notifications/progress",
90-
params: {
91-
progressToken: 0,
92-
progress: 50,
93-
total: 100,
94-
},
74+
test("should reset timeout when progress notification is received", async () => {
75+
await protocol.connect(transport);
76+
const request = { method: "example", params: {} };
77+
const mockSchema: ZodType<{ result: string }> = z.object({
78+
result: z.string(),
9579
});
96-
}
97-
98-
// Run all pending promises to ensure progress handler is called
99-
await Promise.resolve();
100-
101-
// Verify progress handler was called
102-
expect(onProgressMock).toHaveBeenCalledWith({
103-
progress: 50,
104-
total: 100,
105-
});
106-
107-
// Send success response
108-
if (transport.onmessage) {
109-
transport.onmessage({
110-
jsonrpc: "2.0",
111-
id: 0,
112-
result: { result: "success" },
80+
const onProgressMock = jest.fn();
81+
const requestPromise = protocol.request(request, mockSchema, {
82+
timeout: 1000,
83+
resetTimeoutOnProgress: true,
84+
onprogress: onProgressMock,
11385
});
114-
}
115-
116-
// Run all pending promises
117-
await Promise.resolve();
118-
119-
await expect(requestPromise).resolves.toEqual({ result: "success" });
120-
121-
jest.useRealTimers();
122-
});
123-
124-
test("should respect maxTotalTimeout", async () => {
125-
jest.useFakeTimers();
126-
127-
await protocol.connect(transport);
128-
const request = { method: "example", params: {} };
129-
const mockSchema: ZodType<{ result: string }> = z.object({
130-
result: z.string(),
131-
});
132-
133-
const onProgressMock = jest.fn();
134-
const requestPromise = protocol.request(request, mockSchema, {
135-
timeout: 1000,
136-
maxTotalTimeout: 100,
137-
resetTimeoutOnProgress: true,
138-
onprogress: onProgressMock,
139-
});
140-
141-
// Advance time beyond maxTotalTimeout
142-
jest.advanceTimersByTime(150);
143-
144-
// Send progress notification after maxTotalTimeout
145-
if (transport.onmessage) {
146-
transport.onmessage({
147-
jsonrpc: "2.0",
148-
method: "notifications/progress",
149-
params: {
150-
progressToken: 0,
151-
progress: 50,
152-
total: 100,
153-
},
86+
jest.advanceTimersByTime(800);
87+
if (transport.onmessage) {
88+
transport.onmessage({
89+
jsonrpc: "2.0",
90+
method: "notifications/progress",
91+
params: {
92+
progressToken: 0,
93+
progress: 50,
94+
total: 100,
95+
},
96+
});
97+
}
98+
await Promise.resolve();
99+
expect(onProgressMock).toHaveBeenCalledWith({
100+
progress: 50,
101+
total: 100,
154102
});
155-
}
156-
157-
await expect(requestPromise).rejects.toThrow("Maximum total timeout exceeded");
158-
expect(onProgressMock).not.toHaveBeenCalled();
159-
160-
jest.useRealTimers();
161-
});
162-
163-
test("should timeout if no progress received within timeout period", async () => {
164-
jest.useFakeTimers();
165-
166-
await protocol.connect(transport);
167-
const request = { method: "example", params: {} };
168-
const mockSchema: ZodType<{ result: string }> = z.object({
169-
result: z.string(),
170-
});
171-
172-
const requestPromise = protocol.request(request, mockSchema, {
173-
timeout: 100,
174-
resetTimeoutOnProgress: true,
175-
});
176-
177-
// Advance time beyond timeout
178-
jest.advanceTimersByTime(101);
179-
180-
await expect(requestPromise).rejects.toThrow("Request timed out");
181-
182-
jest.useRealTimers();
183-
});
184-
185-
test("should handle multiple progress notifications correctly", async () => {
186-
jest.useFakeTimers();
187-
188-
await protocol.connect(transport);
189-
const request = { method: "example", params: {} };
190-
const mockSchema: ZodType<{ result: string }> = z.object({
191-
result: z.string(),
192-
});
193-
194-
const onProgressMock = jest.fn();
195-
const requestPromise = protocol.request(request, mockSchema, {
196-
timeout: 1000,
197-
resetTimeoutOnProgress: true,
198-
onprogress: onProgressMock,
103+
jest.advanceTimersByTime(800);
104+
if (transport.onmessage) {
105+
transport.onmessage({
106+
jsonrpc: "2.0",
107+
id: 0,
108+
result: { result: "success" },
109+
});
110+
}
111+
await Promise.resolve();
112+
await expect(requestPromise).resolves.toEqual({ result: "success" });
199113
});
200114

201-
// Simulate multiple progress updates
202-
for (let i = 1; i <= 3; i++) {
203-
// Advance close to timeout
204-
jest.advanceTimersByTime(800);
115+
test("should respect maxTotalTimeout", async () => {
116+
await protocol.connect(transport);
117+
const request = { method: "example", params: {} };
118+
const mockSchema: ZodType<{ result: string }> = z.object({
119+
result: z.string(),
120+
});
121+
const onProgressMock = jest.fn();
122+
const requestPromise = protocol.request(request, mockSchema, {
123+
timeout: 1000,
124+
maxTotalTimeout: 150,
125+
resetTimeoutOnProgress: true,
126+
onprogress: onProgressMock,
127+
});
205128

206-
// Send progress notification
129+
// First progress notification should work
130+
jest.advanceTimersByTime(80);
207131
if (transport.onmessage) {
208132
transport.onmessage({
209133
jsonrpc: "2.0",
210134
method: "notifications/progress",
211135
params: {
212136
progressToken: 0,
213-
progress: i * 25,
137+
progress: 50,
214138
total: 100,
215139
},
216140
});
217141
}
218-
219-
// Verify progress handler was called
220142
await Promise.resolve();
221-
expect(onProgressMock).toHaveBeenNthCalledWith(i, {
222-
progress: i * 25,
143+
expect(onProgressMock).toHaveBeenCalledWith({
144+
progress: 50,
223145
total: 100,
224146
});
225-
}
147+
jest.advanceTimersByTime(80);
148+
if (transport.onmessage) {
149+
transport.onmessage({
150+
jsonrpc: "2.0",
151+
method: "notifications/progress",
152+
params: {
153+
progressToken: 0,
154+
progress: 75,
155+
total: 100,
156+
},
157+
});
158+
}
159+
await expect(requestPromise).rejects.toThrow("Maximum total timeout exceeded");
160+
expect(onProgressMock).toHaveBeenCalledTimes(1);
161+
});
226162

227-
// Send success response
228-
if (transport.onmessage) {
229-
transport.onmessage({
230-
jsonrpc: "2.0",
231-
id: 0,
232-
result: { result: "success" },
163+
test("should timeout if no progress received within timeout period", async () => {
164+
await protocol.connect(transport);
165+
const request = { method: "example", params: {} };
166+
const mockSchema: ZodType<{ result: string }> = z.object({
167+
result: z.string(),
233168
});
234-
}
169+
const requestPromise = protocol.request(request, mockSchema, {
170+
timeout: 100,
171+
resetTimeoutOnProgress: true,
172+
});
173+
jest.advanceTimersByTime(101);
174+
await expect(requestPromise).rejects.toThrow("Request timed out");
175+
});
235176

236-
await Promise.resolve();
237-
await expect(requestPromise).resolves.toEqual({ result: "success" });
177+
test("should handle multiple progress notifications correctly", async () => {
178+
await protocol.connect(transport);
179+
const request = { method: "example", params: {} };
180+
const mockSchema: ZodType<{ result: string }> = z.object({
181+
result: z.string(),
182+
});
183+
const onProgressMock = jest.fn();
184+
const requestPromise = protocol.request(request, mockSchema, {
185+
timeout: 1000,
186+
resetTimeoutOnProgress: true,
187+
onprogress: onProgressMock,
188+
});
238189

239-
jest.useRealTimers();
190+
// Simulate multiple progress updates
191+
for (let i = 1; i <= 3; i++) {
192+
jest.advanceTimersByTime(800);
193+
if (transport.onmessage) {
194+
transport.onmessage({
195+
jsonrpc: "2.0",
196+
method: "notifications/progress",
197+
params: {
198+
progressToken: 0,
199+
progress: i * 25,
200+
total: 100,
201+
},
202+
});
203+
}
204+
await Promise.resolve();
205+
expect(onProgressMock).toHaveBeenNthCalledWith(i, {
206+
progress: i * 25,
207+
total: 100,
208+
});
209+
}
210+
if (transport.onmessage) {
211+
transport.onmessage({
212+
jsonrpc: "2.0",
213+
id: 0,
214+
result: { result: "success" },
215+
});
216+
}
217+
await Promise.resolve();
218+
await expect(requestPromise).resolves.toEqual({ result: "success" });
219+
});
240220
});
241221
});
242222

src/shared/protocol.ts

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ export type RequestOptions = {
7373
resetTimeoutOnProgress?: boolean;
7474

7575
/**
76-
* Maximum total time (in milliseconds) to wait for a response, even if progress notifications are received.
77-
* Only used when resetTimeoutOnProgress is true.
76+
* Maximum total time (in milliseconds) to wait for a response.
77+
* If exceeded, an McpError with code `RequestTimeout` will be raised, regardless of progress notifications.
7878
* If not specified, there is no maximum total timeout.
7979
*/
8080
maxTotalTimeout?: number;
@@ -190,19 +190,18 @@ export abstract class Protocol<
190190
});
191191
}
192192

193-
private _resetTimeout(messageId: number, cancel: (reason: unknown) => void): boolean {
193+
private _resetTimeout(messageId: number): boolean {
194194
const info = this._timeoutInfo.get(messageId);
195195
if (!info) return false;
196196

197197
const totalElapsed = Date.now() - info.startTime;
198198
if (info.maxTotalTimeout && totalElapsed >= info.maxTotalTimeout) {
199199
this._timeoutInfo.delete(messageId);
200-
cancel(new McpError(
200+
throw new McpError(
201201
ErrorCode.RequestTimeout,
202202
"Maximum total timeout exceeded",
203203
{ maxTotalTimeout: info.maxTotalTimeout, totalElapsed }
204-
));
205-
return false;
204+
);
206205
}
207206

208207
clearTimeout(info.timeoutId);
@@ -360,7 +359,12 @@ export abstract class Protocol<
360359

361360
const responseHandler = this._responseHandlers.get(messageId);
362361
if (this._timeoutInfo.has(messageId) && responseHandler) {
363-
if (!this._resetTimeout(messageId, (reason) => responseHandler(reason as Error))) {
362+
try {
363+
if (!this._resetTimeout(messageId)) {
364+
return;
365+
}
366+
} catch (error) {
367+
responseHandler(error as Error);
364368
return;
365369
}
366370
}
@@ -518,11 +522,7 @@ export abstract class Protocol<
518522
{ timeout }
519523
));
520524

521-
if (options?.resetTimeoutOnProgress) {
522-
this._setupTimeout(messageId, timeout, options.maxTotalTimeout, timeoutHandler);
523-
} else {
524-
this._setupTimeout(messageId, timeout, undefined, timeoutHandler);
525-
}
525+
this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler);
526526

527527
this._transport.send(jsonrpcRequest).catch((error) => {
528528
this._cleanupTimeout(messageId);

0 commit comments

Comments
 (0)