Skip to content

Commit c365f76

Browse files
committed
wip allowing for resumability
1 parent 4a32966 commit c365f76

File tree

2 files changed

+502
-6
lines changed

2 files changed

+502
-6
lines changed

packages/agents/src/mcp/worker-transport.ts

Lines changed: 151 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ export interface TransportState {
4444
sessionId?: string;
4545
initialized: boolean;
4646
protocolVersion?: ProtocolVersion;
47+
requestMappings?: Record<string, string>;
4748
}
4849

4950
export interface WorkerTransportOptions {
@@ -76,6 +77,7 @@ export class WorkerTransport implements Transport {
7677
private protocolVersion?: ProtocolVersion;
7778
private storage?: MCPStorageApi;
7879
private stateRestored = false;
80+
private eventCounters = new Map<string, number>();
7981

8082
sessionId?: string;
8183
onclose?: () => void;
@@ -90,6 +92,51 @@ export class WorkerTransport implements Transport {
9092
this.storage = options?.storage;
9193
}
9294

95+
/**
96+
* Generate a unique event ID for SSE spec compliance.
97+
* Counter is ephemeral (not persisted) - validation only checks streamId existence.
98+
*/
99+
private generateEventId(streamId: string): string {
100+
if (!this.sessionIdGenerator) {
101+
return `session-${streamId}-${Date.now()}`;
102+
}
103+
104+
const currentCounter = this.eventCounters.get(streamId) || 0;
105+
const nextCounter = currentCounter + 1;
106+
this.eventCounters.set(streamId, nextCounter);
107+
108+
return `${this.sessionId || "session"}-${streamId}-${nextCounter}`;
109+
}
110+
111+
/**
112+
* Parse an event ID to extract streamId and counter.
113+
* Event ID format: "{sessionId}-{streamId}-{counter}"
114+
*/
115+
private parseEventId(
116+
eventId: string
117+
): { streamId: string; counter: number } | null {
118+
const sessionPrefix = this.sessionId ? `${this.sessionId}-` : "session-";
119+
if (!eventId.startsWith(sessionPrefix)) {
120+
return null;
121+
}
122+
123+
const remainder = eventId.slice(sessionPrefix.length);
124+
125+
const lastDash = remainder.lastIndexOf("-");
126+
if (lastDash === -1) {
127+
return null;
128+
}
129+
130+
const counterStr = remainder.slice(lastDash + 1);
131+
const counter = parseInt(counterStr, 10);
132+
if (isNaN(counter)) {
133+
return null;
134+
}
135+
136+
const streamId = remainder.slice(0, lastDash);
137+
return { streamId, counter };
138+
}
139+
93140
/**
94141
* Restore transport state from persistent storage.
95142
* This is automatically called on start.
@@ -107,11 +154,33 @@ export class WorkerTransport implements Transport {
107154
this.protocolVersion = state.protocolVersion;
108155
}
109156

157+
if (state?.requestMappings) {
158+
for (const [requestId, streamId] of Object.entries(
159+
state.requestMappings
160+
)) {
161+
this.requestToStreamMapping.set(requestId, streamId);
162+
}
163+
}
164+
110165
this.stateRestored = true;
111166
}
112167

168+
/**
169+
* Convert a Map to a plain Record for JSON serialization.
170+
*/
171+
private mapToRecord<K extends string | number, V>(
172+
map: Map<K, V>
173+
): Record<string, V> {
174+
const record: Record<string, V> = {};
175+
for (const [key, value] of map.entries()) {
176+
record[String(key)] = value;
177+
}
178+
return record;
179+
}
180+
113181
/**
114182
* Persist current transport state to storage.
183+
* Called on initialization and POST completion only.
115184
*/
116185
private async saveState() {
117186
if (!this.storage) {
@@ -121,7 +190,8 @@ export class WorkerTransport implements Transport {
121190
const state: TransportState = {
122191
sessionId: this.sessionId,
123192
initialized: this.initialized,
124-
protocolVersion: this.protocolVersion
193+
protocolVersion: this.protocolVersion,
194+
requestMappings: this.mapToRecord(this.requestToStreamMapping)
125195
};
126196

127197
await Promise.resolve(this.storage.set(state));
@@ -289,15 +359,73 @@ export class WorkerTransport implements Transport {
289359
return sessionError;
290360
}
291361

292-
// Validate protocol version on subsequent requests
293362
const versionError = this.validateProtocolVersion(request);
294363
if (versionError) {
295364
return versionError;
296365
}
297366

298-
const streamId = this.standaloneSseStreamId;
367+
const lastEventId = request.headers.get("Last-Event-ID");
368+
let streamId: string;
369+
let isResumingStream = false;
370+
371+
if (lastEventId) {
372+
const parsed = this.parseEventId(lastEventId);
373+
if (!parsed) {
374+
return new Response(
375+
JSON.stringify({
376+
jsonrpc: "2.0",
377+
error: {
378+
code: -32000,
379+
message: `Cannot resume: invalid event ID format "${lastEventId}".`
380+
},
381+
id: null
382+
}),
383+
{
384+
status: 404,
385+
headers: {
386+
"Content-Type": "application/json",
387+
...this.getHeaders()
388+
}
389+
}
390+
);
391+
}
392+
393+
// Validate streamId exists by deriving from requestMappings + GET stream
394+
// POST streams are valid if they have pending requests
395+
// GET stream is always valid (persistent notification channel)
396+
const validStreamIds = new Set(
397+
Array.from(this.requestToStreamMapping.values())
398+
);
399+
validStreamIds.add(this.standaloneSseStreamId);
299400

300-
if (this.streamMapping.get(streamId) !== undefined) {
401+
if (!validStreamIds.has(parsed.streamId)) {
402+
return new Response(
403+
JSON.stringify({
404+
jsonrpc: "2.0",
405+
error: {
406+
code: -32000,
407+
message: `Cannot resume: stream "${parsed.streamId}" not found. POST streams expire after responses complete.`
408+
},
409+
id: null
410+
}),
411+
{
412+
status: 404,
413+
headers: {
414+
"Content-Type": "application/json",
415+
...this.getHeaders()
416+
}
417+
}
418+
);
419+
}
420+
421+
streamId = parsed.streamId;
422+
isResumingStream = true;
423+
} else {
424+
streamId = this.standaloneSseStreamId;
425+
}
426+
427+
const existingStream = this.streamMapping.get(streamId);
428+
if (existingStream !== undefined && !isResumingStream) {
301429
return new Response(
302430
JSON.stringify({
303431
jsonrpc: "2.0",
@@ -340,6 +468,9 @@ export class WorkerTransport implements Transport {
340468
}
341469
}, 30000);
342470

471+
if (isResumingStream && existingStream) {
472+
existingStream.cleanup();
473+
}
343474
this.streamMapping.set(streamId, {
344475
writer,
345476
encoder,
@@ -350,6 +481,10 @@ export class WorkerTransport implements Transport {
350481
}
351482
});
352483

484+
const primingEventId = this.generateEventId(streamId);
485+
const primingEvent = `id: ${primingEventId}\ndata: \n\n`;
486+
writer.write(encoder.encode(primingEvent)).catch(() => {});
487+
353488
return new Response(readable, { headers });
354489
}
355490

@@ -608,6 +743,10 @@ export class WorkerTransport implements Transport {
608743
}
609744
}
610745

746+
const primingEventId = this.generateEventId(streamId);
747+
const primingEvent = `id: ${primingEventId}\ndata: \n\n`;
748+
writer.write(encoder.encode(primingEvent)).catch(() => {});
749+
611750
for (const message of messages) {
612751
this.onmessage?.(message, { requestInfo });
613752
}
@@ -738,6 +877,8 @@ export class WorkerTransport implements Transport {
738877

739878
this.streamMapping.clear();
740879
this.requestResponseMap.clear();
880+
this.requestToStreamMapping.clear();
881+
this.eventCounters.clear();
741882
this.onclose?.();
742883
}
743884

@@ -766,7 +907,8 @@ export class WorkerTransport implements Transport {
766907
}
767908

768909
if (standaloneSse.writer && standaloneSse.encoder) {
769-
const data = `event: message\ndata: ${JSON.stringify(message)}\n\n`;
910+
const eventId = this.generateEventId(this.standaloneSseStreamId);
911+
const data = `id: ${eventId}\nevent: message\ndata: ${JSON.stringify(message)}\n\n`;
770912
await standaloneSse.writer.write(standaloneSse.encoder.encode(data));
771913
}
772914
return;
@@ -788,7 +930,8 @@ export class WorkerTransport implements Transport {
788930

789931
if (!this.enableJsonResponse) {
790932
if (response.writer && response.encoder) {
791-
const data = `event: message\ndata: ${JSON.stringify(message)}\n\n`;
933+
const eventId = this.generateEventId(streamId);
934+
const data = `id: ${eventId}\nevent: message\ndata: ${JSON.stringify(message)}\n\n`;
792935
await response.writer.write(response.encoder.encode(data));
793936
}
794937
}
@@ -829,6 +972,8 @@ export class WorkerTransport implements Transport {
829972
this.requestResponseMap.delete(id);
830973
this.requestToStreamMapping.delete(id);
831974
}
975+
976+
this.saveState();
832977
}
833978
}
834979
}

0 commit comments

Comments
 (0)