Skip to content

Commit a66c0c8

Browse files
whoiskatrinthreepointone
authored andcommitted
improve stream completion handling and database sync; add error handling for SQL; updated some tests
1 parent 3d64702 commit a66c0c8

File tree

4 files changed

+157
-64
lines changed

4 files changed

+157
-64
lines changed

packages/agents/src/resumable-stream-manager.ts

Lines changed: 119 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -380,25 +380,77 @@ export class ResumableStreamManager<Message extends ChatMessage = ChatMessage> {
380380
const assistantMessageId = `assistant_${nanoid()}`;
381381
let buffer = "";
382382

383+
let completedNaturally = false;
383384
try {
384385
while (true) {
386+
// Check if stream was completed by onFinish callback
387+
const currentState = this._activeStreams.get(streamId);
388+
if (currentState?.completed) {
389+
// Ensure database state is synchronized
390+
try {
391+
this.sql`
392+
update cf_ai_http_chat_streams
393+
set fetching = 0, completed = 1, updated_at = current_timestamp
394+
where stream_id = ${streamId}
395+
`;
396+
} catch (sqlError) {
397+
console.error(
398+
`[ResumableStreamManager] Error syncing completion state for ${streamId}:`,
399+
sqlError
400+
);
401+
}
402+
403+
completedNaturally = true;
404+
break;
405+
}
406+
385407
const { done, value } = await reader.read();
386-
if (done) break;
408+
if (done) {
409+
completedNaturally = true;
410+
break;
411+
}
387412

388-
// Store raw chunk with sequence number
389-
const seq = streamState.seq++;
413+
// Store raw chunk with sequence number atomically
390414
const chunkBase64 = btoa(String.fromCharCode(...value));
391-
this.sql`
392-
insert into cf_ai_http_chat_chunks (stream_id, seq, chunk)
393-
values (${streamId}, ${seq}, ${chunkBase64})
394-
`;
395415

396-
// Update sequence in stream state
397-
this.sql`
398-
update cf_ai_http_chat_streams
399-
set seq = ${streamState.seq}, updated_at = current_timestamp
400-
where stream_id = ${streamId}
401-
`;
416+
try {
417+
// Atomically get next sequence number and insert chunk
418+
const seqResult = this.sql`
419+
update cf_ai_http_chat_streams
420+
set seq = seq + 1, updated_at = current_timestamp
421+
where stream_id = ${streamId}
422+
returning seq
423+
`;
424+
425+
const seq = Number(seqResult[0]?.seq) || streamState.seq++;
426+
427+
this.sql`
428+
insert into cf_ai_http_chat_chunks (stream_id, seq, chunk)
429+
values (${streamId}, ${seq}, ${chunkBase64})
430+
`;
431+
432+
// Update in-memory state to match database
433+
streamState.seq = seq + 1;
434+
} catch (sqlError) {
435+
console.error(
436+
`[ResumableStreamManager] SQL error for stream ${streamId}:`,
437+
sqlError
438+
);
439+
// Fall back to in-memory sequence if SQL fails
440+
const seq = streamState.seq++;
441+
try {
442+
this.sql`
443+
insert into cf_ai_http_chat_chunks (stream_id, seq, chunk)
444+
values (${streamId}, ${seq}, ${chunkBase64})
445+
`;
446+
} catch (fallbackError) {
447+
console.error(
448+
`[ResumableStreamManager] Fallback SQL error for stream ${streamId}:`,
449+
fallbackError
450+
);
451+
// Continue processing even if storage fails
452+
}
453+
}
402454

403455
// Parse for assistant message content
404456
const chunk = decoder.decode(value, { stream: true });
@@ -456,33 +508,28 @@ export class ResumableStreamManager<Message extends ChatMessage = ChatMessage> {
456508
await persistMessages([...messages, assistantMessage]);
457509
}
458510
} finally {
459-
// Mark stream as completed
460-
streamState.fetching = false;
461-
streamState.completed = true;
462-
463-
this.sql`
464-
update cf_ai_http_chat_streams
465-
set fetching = 0, completed = 1, updated_at = current_timestamp
466-
where stream_id = ${streamId}
467-
`;
468-
469-
// Close all readers/writers
470-
for (const readerOrWriter of streamState.readers) {
511+
// Only mark as completed if stream finished naturally, not if interrupted
512+
if (completedNaturally) {
513+
this._markStreamCompleted(streamId);
514+
} else {
515+
// Stream was interrupted - update fetching state but don't mark as completed
516+
const currentState = this._activeStreams.get(streamId);
517+
if (currentState) {
518+
currentState.fetching = false;
519+
}
471520
try {
472-
if (readerOrWriter instanceof WritableStreamDefaultWriter) {
473-
readerOrWriter.close();
474-
} else {
475-
// Handle ReadableStreamDefaultController
476-
if (
477-
"close" in readerOrWriter &&
478-
typeof readerOrWriter.close === "function"
479-
) {
480-
readerOrWriter.close();
481-
}
482-
}
483-
} catch {}
521+
this.sql`
522+
update cf_ai_http_chat_streams
523+
set fetching = 0, updated_at = current_timestamp
524+
where stream_id = ${streamId}
525+
`;
526+
} catch (sqlError) {
527+
console.error(
528+
`[ResumableStreamManager] Error updating fetching state for ${streamId}:`,
529+
sqlError
530+
);
531+
}
484532
}
485-
streamState.readers.clear();
486533
}
487534
}
488535

@@ -566,13 +613,29 @@ export class ResumableStreamManager<Message extends ChatMessage = ChatMessage> {
566613
}
567614
});
568615

569-
// 2. If still fetching, add to live readers
616+
// 2. Check if stream is truly complete by verifying both in-memory and database state
570617
const currentState = this._activeStreams.get(streamId);
571-
if (currentState?.fetching) {
572-
currentState.readers.add(writer);
618+
619+
// Get the latest database state to ensure consistency
620+
const dbState = this.sql`
621+
select fetching, completed from cf_ai_http_chat_streams
622+
where stream_id = ${streamId}
623+
`[0] as unknown as
624+
| Pick<StreamStateRow, "fetching" | "completed">
625+
| undefined;
626+
627+
const isStillFetching =
628+
currentState?.fetching || dbState?.fetching === 1;
629+
const isCompleted = currentState?.completed && dbState?.completed === 1;
630+
631+
if (isStillFetching && !isCompleted) {
632+
// Stream is still active, join as live reader
633+
if (currentState) {
634+
currentState.readers.add(writer);
635+
}
573636
await this._backfillGaps(streamId, writer, lastSeenSeq + 1);
574637
} else {
575-
// Stream is complete
638+
// Stream is complete, close writer
576639
await writer.close();
577640
}
578641
} catch (error) {
@@ -687,11 +750,20 @@ export class ResumableStreamManager<Message extends ChatMessage = ChatMessage> {
687750
streamState.readers.clear();
688751
}
689752

690-
this.sql`
691-
update cf_ai_http_chat_streams
692-
set fetching = 0, completed = 1, updated_at = current_timestamp
693-
where stream_id = ${streamId}
694-
`;
753+
// Update database state with error handling
754+
try {
755+
this.sql`
756+
update cf_ai_http_chat_streams
757+
set fetching = 0, completed = 1, updated_at = current_timestamp
758+
where stream_id = ${streamId}
759+
`;
760+
} catch (sqlError) {
761+
console.error(
762+
`[ResumableStreamManager] Error marking stream ${streamId} completed:`,
763+
sqlError
764+
);
765+
// Stream is still marked as completed in memory even if SQL fails
766+
}
695767

696768
// Clean up from memory after some time
697769
setTimeout(() => {

packages/agents/src/tests/resumable-streaming.test.ts

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import { createExecutionContext, env } from "cloudflare:test";
22
import { describe, it, expect, beforeEach } from "vitest";
3-
import worker, { type Env, TestResumableStreamAgent } from "./worker";
3+
import worker, { type Env } from "./worker";
44
import { nanoid } from "nanoid";
5-
import type { ExecutionContext } from "@cloudflare/workers-types";
65

76
declare module "cloudflare:test" {
87
interface ProvidedEnv extends Env {}
@@ -12,7 +11,7 @@ async function makeRequest(
1211
agentId: string,
1312
method: string,
1413
path: string,
15-
body?: any
14+
body?: unknown
1615
) {
1716
const ctx = createExecutionContext();
1817
const url = `http://example.com/agents/resumable-stream-agent/${agentId}${path}`;
@@ -31,7 +30,7 @@ async function makeRequest(
3130

3231
async function readPartialStreamChunks(
3332
response: Response,
34-
maxChunks: number = 3
33+
maxChunks = 3
3534
): Promise<{
3635
chunks: string[];
3736
reader: ReadableStreamDefaultReader<Uint8Array>;
@@ -375,18 +374,17 @@ describe("Resumable Streaming - Multiple Clients", () => {
375374
const text1 = extractTextFromSSE(chunk1);
376375
await reader1.cancel();
377376

378-
// Resume and interrupt again
377+
// Wait a bit to let the stream complete in background
378+
await new Promise((resolve) => setTimeout(resolve, 300));
379+
380+
// Resume and read from completed stream (should have all chunks now)
379381
const response2 = await makeRequest(
380382
agentId,
381383
"GET",
382384
`/stream/${customStreamId}`
383385
);
384-
const { chunks: chunk2, reader: reader2 } = await readPartialStreamChunks(
385-
response2,
386-
6
387-
);
386+
const chunk2 = await readStreamChunks(response2);
388387
const text2 = extractTextFromSSE(chunk2);
389-
await reader2.cancel();
390388

391389
// Final resume; should get complete stream
392390
const response3 = await makeRequest(
@@ -816,7 +814,7 @@ describe("Resumable Streaming - Error Handling and Edge Cases", () => {
816814
if (rawChunks.length > 10) break;
817815
}
818816
await reader.cancel();
819-
} catch (error) {
817+
} catch (_error) {
820818
// Expected when canceling
821819
}
822820
}

packages/agents/src/tests/worker.ts

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ import {
1111
type WSMessage
1212
} from "../index.ts";
1313
import { AIChatAgent } from "../ai-chat-agent.ts";
14-
import type { UIMessage as ChatMessage } from "ai";
14+
import type {
15+
UIMessage as ChatMessage,
16+
StreamTextOnFinishCallback,
17+
ToolSet
18+
} from "ai";
1519
import type { MCPClientConnection } from "../mcp/client-connection";
1620
import { AIHttpChatAgent } from "../ai-chat-agent-http.ts";
1721

@@ -404,7 +408,7 @@ export class TestResumableStreamAgent extends AIHttpChatAgent<
404408
// Track requests for testing
405409
requestHistory: Array<{ method: string; url: string; body?: unknown }> = [];
406410

407-
constructor(ctx: any, env: Env) {
411+
constructor(ctx: DurableObjectState, env: Env) {
408412
super(ctx, env);
409413

410414
// Set up some mock responses
@@ -420,7 +424,7 @@ export class TestResumableStreamAgent extends AIHttpChatAgent<
420424
}
421425

422426
async onChatMessage(
423-
onFinish: any,
427+
onFinish: StreamTextOnFinishCallback<ToolSet>,
424428
options?: { streamId?: string }
425429
): Promise<Response | undefined> {
426430
// Track the request
@@ -443,7 +447,7 @@ export class TestResumableStreamAgent extends AIHttpChatAgent<
443447
.join(" ");
444448

445449
// Find mock response or use default
446-
let responseText =
450+
const responseText =
447451
this.mockResponses.get(content.toLowerCase()) || `Echo: ${content}`;
448452

449453
// Create a streaming response
@@ -468,8 +472,22 @@ export class TestResumableStreamAgent extends AIHttpChatAgent<
468472
controller.enqueue(new TextEncoder().encode("data: [DONE]\n\n"));
469473
controller.close();
470474

471-
// Call onFinish
472-
await onFinish();
475+
// Call onFinish with mock result for testing
476+
await onFinish({
477+
text: responseText,
478+
toolCalls: [],
479+
toolResults: [],
480+
finishReason: "stop",
481+
usage: {
482+
promptTokens: 0,
483+
completionTokens: 0,
484+
totalTokens: 0
485+
},
486+
rawResponse: {
487+
headers: {}
488+
},
489+
warnings: []
490+
} as unknown as Parameters<typeof onFinish>[0]);
473491
}
474492
});
475493

packages/agents/src/tests/wrangler.jsonc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
"enable_nodejs_http_modules",
99
"enable_nodejs_perf_hooks_module"
1010
],
11+
12+
"observability": {
13+
"enabled": true,
14+
"head_sampling_rate": 1
15+
},
1116
"durable_objects": {
1217
"bindings": [
1318
{

0 commit comments

Comments
 (0)