@@ -14,6 +14,7 @@ import (
1414 "github.com/google/uuid"
1515 "github.com/openai/openai-go/v2"
1616 "github.com/openai/openai-go/v2/packages/ssestream"
17+ "github.com/tidwall/sjson"
1718
1819 "cdr.dev/slog"
1920)
@@ -117,16 +118,8 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
117118 continue
118119 }
119120
120- // If usage information is available, relay the cumulative usage once all tool invocations have completed.
121- if chunk .Usage .CompletionTokens > 0 {
122- chunk .Usage = processor .getCumulativeUsage ()
123- }
124-
125- // Overwrite response identifier since proxy obscures injected tool call invocations.
126- chunk .ID = i .ID ().String ()
127-
128121 // Marshal and relay chunk to client.
129- payload , err := i .marshal ( chunk )
122+ payload , err := i .marshalChunk ( & chunk , i . ID (), processor )
130123 if err != nil {
131124 logger .Warn (ctx , "failed to marshal chunk" , slog .Error (err ), chunk .RawJSON ())
132125 lastErr = fmt .Errorf ("marshal chunk: %w" , err )
@@ -202,7 +195,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
202195 }
203196
204197 if interceptionErr != nil {
205- payload , err := i .marshal (interceptionErr )
198+ payload , err := i .marshalErr (interceptionErr )
206199 if err != nil {
207200 logger .Warn (ctx , "failed to marshal error" , slog .Error (err ), slog .F ("error_payload" , slog .F ("%+v" , interceptionErr )))
208201 } else if err := events .Send (streamCtx , payload ); err != nil {
@@ -291,10 +284,36 @@ func (i *OpenAIStreamingChatInterception) getInjectedToolByName(name string) *mc
291284 return i .mcpProxy .GetTool (name )
292285}
293286
294- func (i * OpenAIStreamingChatInterception ) marshal (payload any ) ([]byte , error ) {
295- data , err := json .Marshal (payload )
287+ // Mashals received stream chunk.
288+ // Overrides id (since proxy obscures injected tool call invocations).
289+ // If usage field was set in original chunk overrides it to culminative usage.
290+ //
291+ // sjson is used instead of normal struct marshaling so forwarded data
292+ // is as close to the original as possible. Structs from openai library lack
293+ // `omitzero/omitempty` annotations which adds additional empty fields
294+ // when marshaling structs. Those additional empty fields can break Codex client.
295+ func (i * OpenAIStreamingChatInterception ) marshalChunk (chunk * openai.ChatCompletionChunk , id uuid.UUID , prc * openAIStreamProcessor ) ([]byte , error ) {
296+ sj , err := sjson .Set (chunk .RawJSON (), "id" , id .String ())
297+ if err != nil {
298+ return nil , fmt .Errorf ("marshal chunk id failed: %w" , err )
299+ }
300+
301+ // If usage information is available, relay the cumulative usage once all tool invocations have completed.
302+ if chunk .JSON .Usage .Valid () {
303+ u := prc .getCumulativeUsage ()
304+ sj , err = sjson .Set (sj , "usage" , u )
305+ if err != nil {
306+ return nil , fmt .Errorf ("marshal chunk usage failed: %w" , err )
307+ }
308+ }
309+
310+ return i .encodeForStream ([]byte (sj )), nil
311+ }
312+
313+ func (i * OpenAIStreamingChatInterception ) marshalErr (err error ) ([]byte , error ) {
314+ data , err := json .Marshal (err )
296315 if err != nil {
297- return nil , fmt .Errorf ("marshal payload : %w" , err )
316+ return nil , fmt .Errorf ("marshal error failed : %w" , err )
298317 }
299318
300319 return i .encodeForStream (data ), nil
0 commit comments