Skip to content

Commit 4754b1e

Browse files
authored
fix: fix openai stream chunk marshaling (#35)
Fixes chunk marshaling for openai stream interceptions so no empty values are added.
1 parent fa62391 commit 4754b1e

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ require (
1010
github.com/mark3labs/mcp-go v0.38.0
1111
github.com/stretchr/testify v1.10.0
1212
github.com/tidwall/gjson v1.18.0 // indirect
13-
github.com/tidwall/sjson v1.2.5 // indirect
13+
github.com/tidwall/sjson v1.2.5
1414
go.uber.org/goleak v1.3.0
1515
go.uber.org/mock v0.6.0
1616
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b

intercept_openai_chat_streaming.go

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/google/uuid"
1515
"github.com/openai/openai-go/v2"
1616
"github.com/openai/openai-go/v2/packages/ssestream"
17+
"github.com/tidwall/sjson"
1718

1819
"cdr.dev/slog"
1920
)
@@ -117,16 +118,8 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
117118
continue
118119
}
119120

120-
// If usage information is available, relay the cumulative usage once all tool invocations have completed.
121-
if chunk.Usage.CompletionTokens > 0 {
122-
chunk.Usage = processor.getCumulativeUsage()
123-
}
124-
125-
// Overwrite response identifier since proxy obscures injected tool call invocations.
126-
chunk.ID = i.ID().String()
127-
128121
// Marshal and relay chunk to client.
129-
payload, err := i.marshal(chunk)
122+
payload, err := i.marshalChunk(&chunk, i.ID(), processor)
130123
if err != nil {
131124
logger.Warn(ctx, "failed to marshal chunk", slog.Error(err), chunk.RawJSON())
132125
lastErr = fmt.Errorf("marshal chunk: %w", err)
@@ -202,7 +195,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
202195
}
203196

204197
if interceptionErr != nil {
205-
payload, err := i.marshal(interceptionErr)
198+
payload, err := i.marshalErr(interceptionErr)
206199
if err != nil {
207200
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr)))
208201
} else if err := events.Send(streamCtx, payload); err != nil {
@@ -291,10 +284,36 @@ func (i *OpenAIStreamingChatInterception) getInjectedToolByName(name string) *mc
291284
return i.mcpProxy.GetTool(name)
292285
}
293286

294-
func (i *OpenAIStreamingChatInterception) marshal(payload any) ([]byte, error) {
295-
data, err := json.Marshal(payload)
287+
// Mashals received stream chunk.
288+
// Overrides id (since proxy obscures injected tool call invocations).
289+
// If usage field was set in original chunk overrides it to culminative usage.
290+
//
291+
// sjson is used instead of normal struct marshaling so forwarded data
292+
// is as close to the original as possible. Structs from openai library lack
293+
// `omitzero/omitempty` annotations which adds additional empty fields
294+
// when marshaling structs. Those additional empty fields can break Codex client.
295+
func (i *OpenAIStreamingChatInterception) marshalChunk(chunk *openai.ChatCompletionChunk, id uuid.UUID, prc *openAIStreamProcessor) ([]byte, error) {
296+
sj, err := sjson.Set(chunk.RawJSON(), "id", id.String())
297+
if err != nil {
298+
return nil, fmt.Errorf("marshal chunk id failed: %w", err)
299+
}
300+
301+
// If usage information is available, relay the cumulative usage once all tool invocations have completed.
302+
if chunk.JSON.Usage.Valid() {
303+
u := prc.getCumulativeUsage()
304+
sj, err = sjson.Set(sj, "usage", u)
305+
if err != nil {
306+
return nil, fmt.Errorf("marshal chunk usage failed: %w", err)
307+
}
308+
}
309+
310+
return i.encodeForStream([]byte(sj)), nil
311+
}
312+
313+
func (i *OpenAIStreamingChatInterception) marshalErr(err error) ([]byte, error) {
314+
data, err := json.Marshal(err)
296315
if err != nil {
297-
return nil, fmt.Errorf("marshal payload: %w", err)
316+
return nil, fmt.Errorf("marshal error failed: %w", err)
298317
}
299318

300319
return i.encodeForStream(data), nil

0 commit comments

Comments
 (0)