Skip to content

Commit 2333378

Browse files
committed
feat: Handle SSE errors
When an error happens during LLM generation, an error is sent as a message. See https://platform.openai.com/docs/api-reference/responses-streaming/error We now handle it, to transmit the error to the client
1 parent fcb09ea commit 2333378

File tree

2 files changed

+186
-0
lines changed

2 files changed

+186
-0
lines changed

model/rag/chat.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,17 @@ func Query(inst *instance.Instance, logger logger.Logger, query QueryMessage) er
303303
})
304304

305305
if err != nil {
306+
// Send error event to client
307+
errorDoc := map[string]interface{}{
308+
"_id": msg.ID,
309+
"object": "error",
310+
"message": err.Error(),
311+
}
312+
errorPayload := couchdb.JSONDoc{
313+
Type: consts.ChatEvents,
314+
M: errorDoc,
315+
}
316+
realtime.GetHub().Publish(inst, realtime.EventCreate, &errorPayload, nil)
306317
return err
307318
}
308319

@@ -369,6 +380,18 @@ func foreachSSE(r io.Reader, fn func(event map[string]interface{})) error {
369380
if err := json.Unmarshal(data, &event); err != nil {
370381
return err
371382
}
383+
// Check for error event from the server
384+
if errObj, ok := event["error"].(map[string]interface{}); ok {
385+
message, _ := errObj["message"].(string)
386+
code, _ := errObj["code"].(string)
387+
if message == "" {
388+
message = "unknown streaming error"
389+
}
390+
if code != "" {
391+
return fmt.Errorf("%s: %s", code, message)
392+
}
393+
return errors.New(message)
394+
}
372395
fn(event)
373396
}
374397
return nil

model/rag/chat_test.go

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
package rag
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestForeachSSE(t *testing.T) {
12+
t.Run("normal events are passed to callback", func(t *testing.T) {
13+
input := `data: {"object":"chat.completion.chunk","content":"hello"}
14+
15+
data: {"object":"chat.completion.chunk","content":"world"}
16+
17+
data: [DONE]
18+
`
19+
var events []map[string]interface{}
20+
err := foreachSSE(strings.NewReader(input), func(event map[string]interface{}) {
21+
events = append(events, event)
22+
})
23+
24+
require.NoError(t, err)
25+
assert.Len(t, events, 2)
26+
assert.Equal(t, "hello", events[0]["content"])
27+
assert.Equal(t, "world", events[1]["content"])
28+
})
29+
30+
t.Run("error event with code and message returns formatted error", func(t *testing.T) {
31+
input := `data: {"error":{"message":"Error while generating answer","code":"ERROR_ANSWER_GENERATION"}}
32+
`
33+
var events []map[string]interface{}
34+
err := foreachSSE(strings.NewReader(input), func(event map[string]interface{}) {
35+
events = append(events, event)
36+
})
37+
38+
require.Error(t, err)
39+
assert.Equal(t, "ERROR_ANSWER_GENERATION: Error while generating answer", err.Error())
40+
assert.Empty(t, events)
41+
})
42+
43+
t.Run("error event with only message returns error", func(t *testing.T) {
44+
input := `data: {"error":{"message":"Something went wrong"}}
45+
`
46+
var events []map[string]interface{}
47+
err := foreachSSE(strings.NewReader(input), func(event map[string]interface{}) {
48+
events = append(events, event)
49+
})
50+
51+
require.Error(t, err)
52+
assert.Equal(t, "Something went wrong", err.Error())
53+
assert.Empty(t, events)
54+
})
55+
56+
t.Run("error event with empty message returns unknown error", func(t *testing.T) {
57+
input := `data: {"error":{"message":"","code":"SOME_CODE"}}
58+
`
59+
var events []map[string]interface{}
60+
err := foreachSSE(strings.NewReader(input), func(event map[string]interface{}) {
61+
events = append(events, event)
62+
})
63+
64+
require.Error(t, err)
65+
assert.Equal(t, "SOME_CODE: unknown streaming error", err.Error())
66+
assert.Empty(t, events)
67+
})
68+
69+
t.Run("error event with no message field returns unknown error", func(t *testing.T) {
70+
input := `data: {"error":{}}
71+
`
72+
var events []map[string]interface{}
73+
err := foreachSSE(strings.NewReader(input), func(event map[string]interface{}) {
74+
events = append(events, event)
75+
})
76+
77+
require.Error(t, err)
78+
assert.Equal(t, "unknown streaming error", err.Error())
79+
assert.Empty(t, events)
80+
})
81+
82+
t.Run("DONE stops processing", func(t *testing.T) {
83+
input := `data: {"object":"first"}
84+
85+
data: [DONE]
86+
data: {"object":"should not be processed"}
87+
`
88+
var events []map[string]interface{}
89+
err := foreachSSE(strings.NewReader(input), func(event map[string]interface{}) {
90+
events = append(events, event)
91+
})
92+
93+
require.NoError(t, err)
94+
assert.Len(t, events, 1)
95+
assert.Equal(t, "first", events[0]["object"])
96+
})
97+
98+
t.Run("invalid SSE format returns error", func(t *testing.T) {
99+
input := `invalid line without colon
100+
`
101+
err := foreachSSE(strings.NewReader(input), func(event map[string]interface{}) {})
102+
103+
require.Error(t, err)
104+
assert.Equal(t, "invalid SSE response", err.Error())
105+
})
106+
107+
t.Run("invalid JSON returns error", func(t *testing.T) {
108+
input := `data: {invalid json}
109+
`
110+
err := foreachSSE(strings.NewReader(input), func(event map[string]interface{}) {})
111+
112+
require.Error(t, err)
113+
assert.Contains(t, err.Error(), "invalid character")
114+
})
115+
116+
t.Run("comments are skipped", func(t *testing.T) {
117+
input := `: this is a comment
118+
data: {"object":"event"}
119+
120+
data: [DONE]
121+
`
122+
var events []map[string]interface{}
123+
err := foreachSSE(strings.NewReader(input), func(event map[string]interface{}) {
124+
events = append(events, event)
125+
})
126+
127+
require.NoError(t, err)
128+
assert.Len(t, events, 1)
129+
})
130+
131+
t.Run("empty lines are skipped", func(t *testing.T) {
132+
input := `
133+
134+
data: {"object":"event"}
135+
136+
137+
data: [DONE]
138+
`
139+
var events []map[string]interface{}
140+
err := foreachSSE(strings.NewReader(input), func(event map[string]interface{}) {
141+
events = append(events, event)
142+
})
143+
144+
require.NoError(t, err)
145+
assert.Len(t, events, 1)
146+
})
147+
148+
t.Run("non-data fields are skipped", func(t *testing.T) {
149+
input := `event: message
150+
id: 123
151+
data: {"object":"event"}
152+
153+
data: [DONE]
154+
`
155+
var events []map[string]interface{}
156+
err := foreachSSE(strings.NewReader(input), func(event map[string]interface{}) {
157+
events = append(events, event)
158+
})
159+
160+
require.NoError(t, err)
161+
assert.Len(t, events, 1)
162+
})
163+
}

0 commit comments

Comments
 (0)