Commit ba9e790

Fix accumulation of response parts in Chat.streamText (#5931)
1 parent 34fbbb1 commit ba9e790

File tree

.changeset/ripe-buses-punch.md
packages/ai/ai/src/Chat.ts
packages/ai/ai/src/Prompt.ts

3 files changed: +83 -104 lines changed


.changeset/ripe-buses-punch.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+---
+"@effect/ai": patch
+---
+
+Fix the accumulation logic for response parts in the AI `Chat` module

packages/ai/ai/src/Chat.ts

Lines changed: 10 additions & 7 deletions
@@ -49,6 +49,7 @@
 import type { PersistenceBackingError } from "@effect/experimental/Persistence"
 import { BackingPersistence } from "@effect/experimental/Persistence"
 import * as Channel from "effect/Channel"
+import * as Chunk from "effect/Chunk"
 import * as Context from "effect/Context"
 import * as Duration from "effect/Duration"
 import * as Effect from "effect/Effect"
@@ -379,24 +380,26 @@ export const empty: Effect.Effect<Service> = Effect.gen(function*() {
     ),
     streamText: Effect.fnUntraced(
       function*(options) {
-        let combined: Prompt.Prompt = Prompt.empty
+        let parts = Chunk.empty<Response.AnyPart>()
         return Stream.fromChannel(Channel.acquireUseRelease(
           semaphore.take(1).pipe(
             Effect.zipRight(Ref.get(history)),
             Effect.map((history) => Prompt.merge(history, Prompt.make(options.prompt)))
           ),
           (prompt) =>
             LanguageModel.streamText({ ...options, prompt }).pipe(
-              Stream.mapChunksEffect(Effect.fnUntraced(function*(chunk) {
-                const parts = Array.from(chunk)
-                combined = Prompt.merge(combined, Prompt.fromResponseParts(parts))
+              Stream.mapChunks((chunk) => {
+                parts = Chunk.appendAll(parts, chunk)
                 return chunk
-              })),
+              }),
               Stream.toChannel
             ),
-          (parts) =>
+          (prompt) =>
             Effect.zipRight(
-              Ref.set(history, Prompt.merge(parts, combined)),
+              Ref.set(
+                history,
+                Prompt.merge(prompt, Prompt.fromResponseParts(Array.from(parts)))
+              ),
               semaphore.release(1)
             )
         )).pipe(
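
The shape of the fix is visible in the hunk above: instead of converting every streamed chunk into a `Prompt` on its own with `Prompt.fromResponseParts` and merging it into a running `combined` prompt, the raw response parts are now appended to a `Chunk` as the stream flows and converted into a `Prompt` exactly once, when the history `Ref` is updated in the release step. A minimal standalone sketch of that accumulation pattern, with plain strings standing in for `Response.AnyPart` values (an assumption for illustration only):

```ts
import * as Chunk from "effect/Chunk"
import * as Effect from "effect/Effect"
import * as Stream from "effect/Stream"

// Sketch of the accumulation pattern: each chunk flowing through the stream
// is appended to `parts` and then passed along unchanged, so consumers still
// see every streamed element while the full sequence remains available once
// the stream has finished.
let parts = Chunk.empty<string>()

const tapped = Stream.make("Hello", ", ", "world").pipe(
  Stream.mapChunks((chunk) => {
    parts = Chunk.appendAll(parts, chunk)
    return chunk
  })
)

Effect.runPromise(Stream.runDrain(tapped)).then(() => {
  // After the stream completes, `parts` holds everything that was emitted.
  console.log(Chunk.toReadonlyArray(parts)) // [ "Hello", ", ", "world" ]
})
```

In the actual change, that accumulated `Chunk` is what feeds `Prompt.fromResponseParts` when the history `Ref` is written, replacing the per-chunk `Prompt.merge` of the previous implementation.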

packages/ai/ai/src/Prompt.ts

Lines changed: 68 additions & 97 deletions
@@ -1511,38 +1511,6 @@ export const make = (input: RawInput): Prompt => {
  */
 export const fromMessages = (messages: ReadonlyArray<Message>): Prompt => makePrompt(messages)
 
-const VALID_RESPONSE_PART_MAP = {
-  "response-metadata": false,
-  "text": true,
-  "text-start": false,
-  "text-delta": true,
-  "text-end": false,
-  "reasoning": true,
-  "reasoning-start": false,
-  "reasoning-delta": true,
-  "reasoning-end": false,
-  "file": false,
-  "source": false,
-  "tool-params-start": false,
-  "tool-params-delta": false,
-  "tool-params-end": false,
-  "tool-call": true,
-  "tool-result": true,
-  "finish": false,
-  "error": false
-} as const satisfies Record<Response.AnyPart["type"], boolean>
-
-type ValidResponseParts = typeof VALID_RESPONSE_PART_MAP
-
-type ValidResponsePart = {
-  [Type in keyof ValidResponseParts]: ValidResponseParts[Type] extends true ? Extract<Response.AnyPart, { type: Type }>
-    : never
-}[keyof typeof VALID_RESPONSE_PART_MAP]
-
-const isValidPart = (part: Response.AnyPart): part is ValidResponsePart => {
-  return VALID_RESPONSE_PART_MAP[part.type]
-}
-
 /**
  * Creates a Prompt from the response parts of a previous interaction with a
  * large language model.
@@ -1590,93 +1558,96 @@ export const fromResponseParts = (parts: ReadonlyArray<Response.AnyPart>): Promp
   const assistantParts: Array<AssistantMessagePart> = []
   const toolParts: Array<ToolMessagePart> = []
 
-  const textDeltas: Array<string> = []
-  function flushTextDeltas() {
-    if (textDeltas.length > 0) {
-      const text = textDeltas.join("")
-      if (text.length > 0) {
-        assistantParts.push(makePart("text", { text }))
-      }
-      textDeltas.length = 0
-    }
-  }
+  const activeTextDeltas = new Map<string, { text: string }>()
+  const activeReasoningDeltas = new Map<string, { text: string }>()
 
-  const reasoningDeltas: Array<string> = []
-  function flushReasoningDeltas() {
-    if (reasoningDeltas.length > 0) {
-      const text = reasoningDeltas.join("")
-      if (text.length > 0) {
-        assistantParts.push(makePart("reasoning", { text }))
+  for (const part of parts) {
+    switch (part.type) {
+      // Text Parts
+      case "text": {
+        assistantParts.push(makePart("text", { text: part.text }))
+        break
       }
-      reasoningDeltas.length = 0
-    }
-  }
 
-  function flushDeltas() {
-    flushTextDeltas()
-    flushReasoningDeltas()
-  }
-
-  for (const part of parts) {
-    if (isValidPart(part)) {
-      switch (part.type) {
-        case "text": {
-          flushDeltas()
-          assistantParts.push(makePart("text", { text: part.text }))
-          break
-        }
-        case "text-delta": {
-          flushReasoningDeltas()
-          textDeltas.push(part.delta)
-          break
-        }
-        case "reasoning": {
-          flushDeltas()
-          assistantParts.push(makePart("reasoning", { text: part.text }))
-          break
+      // Text Parts (streaming)
+      case "text-start": {
+        activeTextDeltas.set(part.id, { text: "" })
+        break
+      }
+      case "text-delta": {
+        if (activeTextDeltas.has(part.id)) {
+          activeTextDeltas.get(part.id)!.text += part.delta
         }
-        case "reasoning-delta": {
-          flushTextDeltas()
-          reasoningDeltas.push(part.delta)
-          break
+        break
+      }
+      case "text-end": {
+        if (activeTextDeltas.has(part.id)) {
+          assistantParts.push(makePart("text", activeTextDeltas.get(part.id)!))
         }
-        case "tool-call": {
-          flushDeltas()
-          assistantParts.push(makePart("tool-call", {
-            id: part.id,
-            name: part.providerName ?? part.name,
-            params: part.params,
-            providerExecuted: part.providerExecuted ?? false
-          }))
-          break
+        break
+      }
+
+      // Reasoning Parts
+      case "reasoning": {
+        assistantParts.push(makePart("reasoning", { text: part.text }))
+        break
+      }
+
+      // Reasoning Parts (streaming)
+      case "reasoning-start": {
+        activeReasoningDeltas.set(part.id, { text: "" })
+        break
+      }
+      case "reasoning-delta": {
+        if (activeReasoningDeltas.has(part.id)) {
+          activeReasoningDeltas.get(part.id)!.text += part.delta
         }
-        case "tool-result": {
-          flushDeltas()
-          toolParts.push(makePart("tool-result", {
-            id: part.id,
-            name: part.providerName ?? part.name,
-            isFailure: part.isFailure,
-            result: part.encodedResult
-          }))
-          break
+        break
+      }
+      case "reasoning-end": {
+        if (activeReasoningDeltas.has(part.id)) {
+          assistantParts.push(makePart("reasoning", activeReasoningDeltas.get(part.id)!))
         }
+        break
+      }
+
+      // Tool Call Parts
+      case "tool-call": {
+        assistantParts.push(makePart("tool-call", {
+          id: part.id,
+          name: part.providerName ?? part.name,
+          params: part.params,
+          providerExecuted: part.providerExecuted ?? false
+        }))
+        break
+      }
+
+      // Tool Result Parts
+      case "tool-result": {
+        toolParts.push(makePart("tool-result", {
+          id: part.id,
+          name: part.providerName ?? part.name,
+          isFailure: part.isFailure,
+          result: part.encodedResult
+        }))
       }
     }
   }
 
-  flushDeltas()
-
   if (assistantParts.length === 0 && toolParts.length === 0) {
     return empty
   }
 
   const messages: Array<Message> = []
+
   if (assistantParts.length > 0) {
     messages.push(makeMessage("assistant", { content: assistantParts }))
   }
+
   if (toolParts.length > 0) {
     messages.push(makeMessage("tool", { content: toolParts }))
   }
+
   return makePrompt(messages)
 }
 
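The rewritten loop above replaces the old flush-on-type-change bookkeeping with per-id buffers: a `text-start` / `reasoning-start` part opens a buffer keyed by the part's `id`, each matching `*-delta` appends to it, and the corresponding `*-end` emits the assembled part, so interleaved streams with distinct ids no longer bleed into one another. A simplified, self-contained sketch of that assembly logic (the reduced `StreamPart` shape is an assumption standing in for `Response.AnyPart`):

```ts
// Per-id assembly of streamed text deltas, mirroring the structure of the
// rewritten fromResponseParts loop above.
type StreamPart =
  | { type: "text-start"; id: string }
  | { type: "text-delta"; id: string; delta: string }
  | { type: "text-end"; id: string }

const assembleText = (parts: ReadonlyArray<StreamPart>): Array<string> => {
  const active = new Map<string, { text: string }>()
  const out: Array<string> = []
  for (const part of parts) {
    switch (part.type) {
      case "text-start": {
        // Open a buffer for this id
        active.set(part.id, { text: "" })
        break
      }
      case "text-delta": {
        // Only accumulate deltas for ids that were opened by a start part
        if (active.has(part.id)) active.get(part.id)!.text += part.delta
        break
      }
      case "text-end": {
        // Emit the fully assembled text for this id
        if (active.has(part.id)) out.push(active.get(part.id)!.text)
        break
      }
    }
  }
  return out
}

console.log(assembleText([
  { type: "text-start", id: "1" },
  { type: "text-delta", id: "1", delta: "Hello, " },
  { type: "text-delta", id: "1", delta: "world" },
  { type: "text-end", id: "1" }
])) // [ "Hello, world" ]
```

Deltas whose `id` was never opened by a start part are simply ignored, which is also why the `VALID_RESPONSE_PART_MAP` allow-list removed in the first hunk is no longer needed: the `switch` only folds the part types it knows into messages and lets everything else fall through.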