Skip to content

Commit a8de323

Browse files
committed
fix: enforce monotonic progress timeouts
1 parent 4fbcfcd commit a8de323

2 files changed

Lines changed: 81 additions & 7 deletions

File tree

packages/core/src/shared/protocol.ts

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
316316
private _notificationHandlers: Map<string, (notification: JSONRPCNotification) => Promise<void>> = new Map();
317317
private _responseHandlers: Map<number, (response: JSONRPCResultResponse | Error) => void> = new Map();
318318
private _progressHandlers: Map<number, ProgressCallback> = new Map();
319+
private _progressValues: Map<number, number> = new Map();
319320
private _timeoutInfo: Map<number, TimeoutInfo> = new Map();
320321
private _pendingDebouncedNotifications = new Set<string>();
321322

@@ -383,7 +384,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
383384
request: (request, resultSchema, options) => this._requestWithSchema(request, resultSchema, options),
384385
notification: (notification, options) => this.notification(notification, options),
385386
reportError: error => this._onerror(error),
386-
removeProgressHandler: token => this._progressHandlers.delete(token),
387+
removeProgressHandler: token => this._removeProgressHandler(token),
387388
registerHandler: (method, handler) => {
388389
const schema = getRequestSchema(method as RequestMethod);
389390
this._requestHandlers.set(method, (request, ctx) => {
@@ -460,6 +461,11 @@ export abstract class Protocol<ContextT extends BaseContext> {
460461
}
461462
}
462463

464+
private _removeProgressHandler(messageId: number): void {
465+
this._progressHandlers.delete(messageId);
466+
this._progressValues.delete(messageId);
467+
}
468+
463469
/**
464470
* Attaches to the given transport, starts it, and starts listening for messages.
465471
*
@@ -506,6 +512,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
506512
const responseHandlers = this._responseHandlers;
507513
this._responseHandlers = new Map();
508514
this._progressHandlers.clear();
515+
this._progressValues.clear();
509516
this._taskManager.onClose();
510517
this._pendingDebouncedNotifications.clear();
511518

@@ -702,14 +709,24 @@ export abstract class Protocol<ContextT extends BaseContext> {
702709

703710
const responseHandler = this._responseHandlers.get(messageId);
704711
const timeoutInfo = this._timeoutInfo.get(messageId);
712+
const lastProgress = this._progressValues.get(messageId);
713+
if (lastProgress !== undefined && params.progress <= lastProgress) {
714+
this._onerror(
715+
new Error(
716+
`Received a non-increasing progress notification for token ${progressToken}: ${params.progress} <= ${lastProgress}`
717+
)
718+
);
719+
return;
720+
}
721+
this._progressValues.set(messageId, params.progress);
705722

706723
if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) {
707724
try {
708725
this._resetTimeout(messageId);
709726
} catch (error) {
710727
// Clean up if maxTotalTimeout was exceeded
711728
this._responseHandlers.delete(messageId);
712-
this._progressHandlers.delete(messageId);
729+
this._removeProgressHandler(messageId);
713730
this._cleanupTimeout(messageId);
714731
responseHandler(error as Error);
715732
return;
@@ -738,7 +755,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
738755

739756
// Keep progress handler alive for CreateTaskResult responses
740757
if (!preserveProgress) {
741-
this._progressHandlers.delete(messageId);
758+
this._removeProgressHandler(messageId);
742759
}
743760

744761
if (isJSONRPCResultResponse(response)) {
@@ -890,7 +907,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
890907
if (responseReceived) {
891908
return;
892909
}
893-
this._progressHandlers.delete(messageId);
910+
this._removeProgressHandler(messageId);
894911

895912
this._transport
896913
?.send(
@@ -951,22 +968,22 @@ export abstract class Protocol<ContextT extends BaseContext> {
951968
let outboundQueued = false;
952969
try {
953970
const taskResult = this._taskManager.processOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, error => {
954-
this._progressHandlers.delete(messageId);
971+
this._removeProgressHandler(messageId);
955972
reject(error);
956973
});
957974
if (taskResult.queued) {
958975
outboundQueued = true;
959976
}
960977
} catch (error) {
961-
this._progressHandlers.delete(messageId);
978+
this._removeProgressHandler(messageId);
962979
reject(error);
963980
return;
964981
}
965982

966983
if (!outboundQueued) {
967984
// No related task or no module - send through transport normally
968985
this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => {
969-
this._progressHandlers.delete(messageId);
986+
this._removeProgressHandler(messageId);
970987
reject(error);
971988
});
972989
}

packages/core/test/shared/protocol.test.ts

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,63 @@ describe('protocol tests', () => {
556556
await expect(requestPromise).resolves.toEqual({ result: 'success' });
557557
});
558558

559+
test('should not reset timeout for non-increasing progress notifications', async () => {
560+
await protocol.connect(transport);
561+
const request = { method: 'example', params: {} };
562+
const mockSchema: ZodType<{ result: string }> = z.object({
563+
result: z.string()
564+
});
565+
const onErrorMock = vi.fn();
566+
const onProgressMock = vi.fn();
567+
protocol.onerror = onErrorMock;
568+
569+
const requestPromise = testRequest(protocol, request, mockSchema, {
570+
timeout: 1000,
571+
resetTimeoutOnProgress: true,
572+
onprogress: onProgressMock
573+
});
574+
575+
vi.advanceTimersByTime(800);
576+
if (transport.onmessage) {
577+
transport.onmessage({
578+
jsonrpc: '2.0',
579+
method: 'notifications/progress',
580+
params: {
581+
progressToken: 0,
582+
progress: 50,
583+
total: 100
584+
}
585+
});
586+
}
587+
await Promise.resolve();
588+
589+
expect(onProgressMock).toHaveBeenCalledOnce();
590+
expect(onProgressMock).toHaveBeenCalledWith({
591+
progress: 50,
592+
total: 100
593+
});
594+
595+
vi.advanceTimersByTime(800);
596+
if (transport.onmessage) {
597+
transport.onmessage({
598+
jsonrpc: '2.0',
599+
method: 'notifications/progress',
600+
params: {
601+
progressToken: 0,
602+
progress: 25,
603+
total: 100
604+
}
605+
});
606+
}
607+
await Promise.resolve();
608+
609+
expect(onErrorMock).toHaveBeenCalledWith(expect.objectContaining({ message: expect.stringContaining('non-increasing') }));
610+
expect(onProgressMock).toHaveBeenCalledOnce();
611+
612+
vi.advanceTimersByTime(201);
613+
await expect(requestPromise).rejects.toThrow('Request timed out');
614+
});
615+
559616
test('should respect maxTotalTimeout', async () => {
560617
await protocol.connect(transport);
561618
const request = { method: 'example', params: {} };

0 commit comments

Comments
 (0)