Skip to content

Commit be7ab37

Browse files
committed
progress manager split
1 parent 92f5ce5 commit be7ab37

File tree

4 files changed

+193
-31
lines changed

4 files changed

+193
-31
lines changed

packages/core/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ export * from './auth/errors.js';
22
export * from './shared/auth.js';
33
export * from './shared/authUtils.js';
44
export * from './shared/metadataUtils.js';
5+
export * from './shared/progressManager.js';
56
export * from './shared/protocol.js';
67
export * from './shared/responseMessage.js';
78
export * from './shared/stdio.js';
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/**
2+
* Progress Manager
3+
*
4+
* Manages progress tracking for the Protocol class.
5+
* Extracted from Protocol to follow Single Responsibility Principle.
6+
*/
7+
8+
import type { Progress, ProgressNotification } from '../types/types.js';
9+
10+
/**
11+
* Callback for progress notifications.
12+
*/
13+
export type ProgressCallback = (progress: Progress) => void;
14+
15+
/**
16+
* Manages progress tracking for requests.
17+
*
18+
* This class handles registration, lookup, and invocation of progress callbacks,
19+
* as well as task-to-progress-token associations for long-running task operations.
20+
*
21+
* @example
22+
* ```typescript
23+
* const progressManager = new ProgressManager();
24+
*
25+
* // Register a progress handler for a request
26+
* progressManager.registerHandler(messageId, (progress) => {
27+
* console.log(`Progress: ${progress.progress}/${progress.total}`);
28+
* });
29+
*
30+
* // Handle incoming progress notification
31+
* progressManager.handleProgress(notification);
32+
*
33+
* // Clean up when done
34+
* progressManager.removeHandler(messageId);
35+
* ```
36+
*/
37+
export class ProgressManager {
38+
/**
39+
* Maps message IDs to progress callbacks.
40+
*/
41+
#progressHandlers: Map<number, ProgressCallback> = new Map();
42+
43+
/**
44+
* Maps task IDs to progress tokens to keep handlers alive after CreateTaskResult.
45+
*/
46+
#taskProgressTokens: Map<string, number> = new Map();
47+
48+
/**
49+
* Registers a progress callback for a message.
50+
*
51+
* @param messageId - The message ID (used as progress token)
52+
* @param callback - The callback to invoke when progress is received
53+
*/
54+
registerHandler(messageId: number, callback: ProgressCallback): void {
55+
this.#progressHandlers.set(messageId, callback);
56+
}
57+
58+
/**
59+
* Gets the progress callback for a message.
60+
*
61+
* @param messageId - The message ID
62+
* @returns The progress callback or undefined if not registered
63+
*/
64+
getHandler(messageId: number): ProgressCallback | undefined {
65+
return this.#progressHandlers.get(messageId);
66+
}
67+
68+
/**
69+
* Removes the progress callback for a message.
70+
*
71+
* @param messageId - The message ID
72+
*/
73+
removeHandler(messageId: number): void {
74+
this.#progressHandlers.delete(messageId);
75+
}
76+
77+
/**
78+
* Checks if a progress handler exists for the given message ID.
79+
*
80+
* @param messageId - The message ID
81+
* @returns true if a handler is registered, false otherwise
82+
*/
83+
hasHandler(messageId: number): boolean {
84+
return this.#progressHandlers.has(messageId);
85+
}
86+
87+
/**
88+
* Handles an incoming progress notification by invoking the registered callback.
89+
* Returns true if the progress was handled, false if no handler was found.
90+
*
91+
* @param notification - The progress notification
92+
* @returns true if handled, false otherwise
93+
*/
94+
handleProgress(notification: ProgressNotification): boolean {
95+
const token = notification.params.progressToken;
96+
if (typeof token !== 'number') {
97+
// Token must be a number for our internal tracking
98+
return false;
99+
}
100+
101+
const callback = this.#progressHandlers.get(token);
102+
if (callback) {
103+
callback({
104+
progress: notification.params.progress,
105+
total: notification.params.total,
106+
message: notification.params.message
107+
});
108+
return true;
109+
}
110+
111+
return false;
112+
}
113+
114+
/**
115+
* Links a task ID to a progress token.
116+
* This keeps the progress handler alive after CreateTaskResult is returned,
117+
* allowing progress notifications to continue for long-running tasks.
118+
*
119+
* @param taskId - The task identifier
120+
* @param progressToken - The progress token (message ID)
121+
*/
122+
linkTaskToProgressToken(taskId: string, progressToken: number): void {
123+
this.#taskProgressTokens.set(taskId, progressToken);
124+
}
125+
126+
/**
127+
* Gets the progress token associated with a task.
128+
*
129+
* @param taskId - The task identifier
130+
* @returns The progress token or undefined if not linked
131+
*/
132+
getTaskProgressToken(taskId: string): number | undefined {
133+
return this.#taskProgressTokens.get(taskId);
134+
}
135+
136+
/**
137+
* Cleans up the progress handler associated with a task.
138+
* Should be called when a task reaches a terminal status.
139+
*
140+
* @param taskId - The task identifier
141+
*/
142+
cleanupTaskProgressHandler(taskId: string): void {
143+
const progressToken = this.#taskProgressTokens.get(taskId);
144+
if (progressToken !== undefined) {
145+
this.#progressHandlers.delete(progressToken);
146+
this.#taskProgressTokens.delete(taskId);
147+
}
148+
}
149+
150+
/**
151+
* Clears all progress handlers and task progress tokens.
152+
* Typically called when the connection is closed.
153+
*/
154+
clear(): void {
155+
this.#progressHandlers.clear();
156+
this.#taskProgressTokens.clear();
157+
}
158+
159+
/**
160+
* Gets the number of active progress handlers.
161+
*/
162+
get handlerCount(): number {
163+
return this.#progressHandlers.size;
164+
}
165+
166+
/**
167+
* Gets the number of active task-to-progress-token links.
168+
*/
169+
get taskTokenCount(): number {
170+
return this.#taskProgressTokens.size;
171+
}
172+
}

packages/core/src/shared/protocol.ts

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import type {
1616
Notification,
1717
NotificationMethod,
1818
NotificationTypeMap,
19-
Progress,
2019
ProgressNotification,
2120
RelatedTaskMetadata,
2221
Request,
@@ -51,14 +50,11 @@ import {
5150
import type { AnySchema, SchemaOutput } from '../util/zodCompat.js';
5251
import { safeParse } from '../util/zodCompat.js';
5352
import { parseWithCompat } from '../util/zodJsonSchemaCompat.js';
53+
import type { ProgressCallback } from './progressManager.js';
54+
import { ProgressManager } from './progressManager.js';
5455
import type { ResponseMessage } from './responseMessage.js';
5556
import type { Transport, TransportSendOptions } from './transport.js';
5657

57-
/**
58-
* Callback for progress notifications.
59-
*/
60-
export type ProgressCallback = (progress: Progress) => void;
61-
6258
/**
6359
* Additional initialization options.
6460
*/
@@ -330,13 +326,10 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
330326
private _requestHandlerAbortControllers: Map<RequestId, AbortController> = new Map();
331327
private _notificationHandlers: Map<string, (notification: JSONRPCNotification) => Promise<void>> = new Map();
332328
private _responseHandlers: Map<number, (response: JSONRPCResultResponse | Error) => void> = new Map();
333-
private _progressHandlers: Map<number, ProgressCallback> = new Map();
329+
private _progressManager: ProgressManager = new ProgressManager();
334330
private _timeoutInfo: Map<number, TimeoutInfo> = new Map();
335331
private _pendingDebouncedNotifications = new Set<string>();
336332

337-
// Maps task IDs to progress tokens to keep handlers alive after CreateTaskResult
338-
private _taskProgressTokens: Map<string, number> = new Map();
339-
340333
private _taskStore?: TaskStore;
341334
private _taskMessageQueue?: TaskMessageQueue;
342335

@@ -639,8 +632,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
639632
private _onclose(): void {
640633
const responseHandlers = this._responseHandlers;
641634
this._responseHandlers = new Map();
642-
this._progressHandlers.clear();
643-
this._taskProgressTokens.clear();
635+
this._progressManager.clear();
644636
this._pendingDebouncedNotifications.clear();
645637

646638
const error = McpError.fromError(ErrorCode.ConnectionClosed, 'Connection closed');
@@ -826,11 +818,10 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
826818
}
827819

828820
private _onprogress(notification: ProgressNotification): void {
829-
const { progressToken, ...params } = notification.params;
821+
const progressToken = notification.params.progressToken;
830822
const messageId = Number(progressToken);
831823

832-
const handler = this._progressHandlers.get(messageId);
833-
if (!handler) {
824+
if (!this._progressManager.hasHandler(messageId)) {
834825
this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`));
835826
return;
836827
}
@@ -844,14 +835,14 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
844835
} catch (error) {
845836
// Clean up if maxTotalTimeout was exceeded
846837
this._responseHandlers.delete(messageId);
847-
this._progressHandlers.delete(messageId);
838+
this._progressManager.removeHandler(messageId);
848839
this._cleanupTimeout(messageId);
849840
responseHandler(error as Error);
850841
return;
851842
}
852843
}
853844

854-
handler(params);
845+
this._progressManager.handleProgress(notification);
855846
}
856847

857848
private _onresponse(response: JSONRPCResponse | JSONRPCErrorResponse): void {
@@ -887,13 +878,13 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
887878
const task = result.task as Record<string, unknown>;
888879
if (typeof task.taskId === 'string') {
889880
isTaskResponse = true;
890-
this._taskProgressTokens.set(task.taskId, messageId);
881+
this._progressManager.linkTaskToProgressToken(task.taskId, messageId);
891882
}
892883
}
893884
}
894885

895886
if (!isTaskResponse) {
896-
this._progressHandlers.delete(messageId);
887+
this._progressManager.removeHandler(messageId);
897888
}
898889

899890
if (isJSONRPCResultResponse(response)) {
@@ -1116,7 +1107,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
11161107
};
11171108

11181109
if (options?.onprogress) {
1119-
this._progressHandlers.set(messageId, options.onprogress);
1110+
this._progressManager.registerHandler(messageId, options.onprogress);
11201111
jsonrpcRequest.params = {
11211112
...request.params,
11221113
_meta: {
@@ -1147,7 +1138,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
11471138

11481139
const cancel = (reason: unknown) => {
11491140
this._responseHandlers.delete(messageId);
1150-
this._progressHandlers.delete(messageId);
1141+
this._progressManager.removeHandler(messageId);
11511142
this._cleanupTimeout(messageId);
11521143

11531144
this._transport
@@ -1459,11 +1450,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
14591450
* This should be called when a task reaches a terminal status.
14601451
*/
14611452
private _cleanupTaskProgressHandler(taskId: string): void {
1462-
const progressToken = this._taskProgressTokens.get(taskId);
1463-
if (progressToken !== undefined) {
1464-
this._progressHandlers.delete(progressToken);
1465-
this._taskProgressTokens.delete(taskId);
1466-
}
1453+
this._progressManager.cleanupTaskProgressHandler(taskId);
14671454
}
14681455

14691456
/**

packages/core/test/shared/protocol.test.ts

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ interface TestProtocol {
3636
_taskMessageQueue?: TaskMessageQueue;
3737
_requestResolvers: Map<RequestId, (response: JSONRPCResultResponse | Error) => void>;
3838
_responseHandlers: Map<RequestId, (response: JSONRPCResultResponse | Error) => void>;
39-
_taskProgressTokens: Map<string, number>;
39+
_progressManager: {
40+
getTaskProgressToken(taskId: string): number | undefined;
41+
cleanupTaskProgressHandler(taskId: string): void;
42+
};
4043
_clearTaskQueue: (taskId: string, sessionId?: string) => Promise<void>;
4144
requestTaskStore: (request: Request, authInfo: unknown) => TaskStore;
4245
// Protected task methods (exposed for testing)
@@ -2564,9 +2567,8 @@ describe('Progress notification support for tasks', () => {
25642567
expect(progressCallback).toHaveBeenCalledTimes(1);
25652568

25662569
// Verify the task-progress association was created
2567-
const taskProgressTokens = (protocol as unknown as TestProtocol)._taskProgressTokens as Map<string, number>;
2568-
expect(taskProgressTokens.has(taskId)).toBe(true);
2569-
expect(taskProgressTokens.get(taskId)).toBe(progressToken);
2570+
const progressManager = (protocol as unknown as TestProtocol)._progressManager;
2571+
expect(progressManager.getTaskProgressToken(taskId)).toBe(progressToken);
25702572

25712573
// Simulate task completion by calling through the protocol's task store
25722574
// This will trigger the cleanup logic
@@ -2578,7 +2580,7 @@ describe('Progress notification support for tasks', () => {
25782580
await new Promise(resolve => setTimeout(resolve, 50));
25792581

25802582
// Verify the association was cleaned up
2581-
expect(taskProgressTokens.has(taskId)).toBe(false);
2583+
expect(progressManager.getTaskProgressToken(taskId)).toBeUndefined();
25822584

25832585
// Try to send progress notification after task completion - should be ignored
25842586
progressCallback.mockClear();

0 commit comments

Comments
 (0)