Skip to content

Commit 95c2c78

Browse files
committed
convertStreamingResponse return error in case of error instead of int
1 parent a17895e commit 95c2c78

File tree

1 file changed

+76
-20
lines changed

1 file changed

+76
-20
lines changed

pkg/metrics/openai_recorder.go

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package metrics
33
import (
44
"bytes"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"net/http"
89
"strings"
@@ -18,6 +19,28 @@ import (
1819
// per model.
1920
const maximumRecordsPerModel = 10
2021

22+
// StreamingError represents an error that occurred during streaming response processing.
23+
// It contains the HTTP status code and additional context about the error.
24+
type StreamingError struct {
25+
StatusCode int `json:"status_code"`
26+
Message string `json:"message"`
27+
Type string `json:"type,omitempty"`
28+
Details string `json:"details,omitempty"`
29+
}
30+
31+
// Error implements the error interface for StreamingError.
32+
func (e *StreamingError) Error() string {
33+
if e.Type != "" {
34+
return fmt.Sprintf("streaming error (code %d, type %s): %s", e.StatusCode, e.Type, e.Message)
35+
}
36+
return fmt.Sprintf("streaming error (code %d): %s", e.StatusCode, e.Message)
37+
}
38+
39+
// StatusCode returns the HTTP status code associated with this streaming error.
40+
func (e *StreamingError) GetStatusCode() int {
41+
return e.StatusCode
42+
}
43+
2144
type responseRecorder struct {
2245
http.ResponseWriter
2346
body *bytes.Buffer
@@ -191,13 +214,9 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
191214
}
192215

193216
var response string
194-
var streamingErrorCode int
217+
var streamingErr error
195218
if strings.Contains(responseBody, "data: ") {
196-
response, streamingErrorCode = r.convertStreamingResponse(responseBody)
197-
// If streaming error detected, use that status code instead
198-
if streamingErrorCode >= 400 {
199-
statusCode = streamingErrorCode
200-
}
219+
response, streamingErr = r.convertStreamingResponse(responseBody)
201220
} else {
202221
response = responseBody
203222
}
@@ -211,8 +230,23 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
211230
for _, record := range modelData.Records {
212231
if record.ID == id {
213232
record.StatusCode = statusCode
214-
// Populate either Response or Error field based on status code
215-
if statusCode >= 400 {
233+
if streamingErr != nil {
234+
// Check if it's a StreamingError and serialize it directly to JSON
235+
var streamingError *StreamingError
236+
if errors.As(streamingErr, &streamingError) {
237+
// Serialize StreamingError directly to JSON for a rich error structure
238+
if errorJSON, err := json.Marshal(streamingError); err == nil {
239+
record.Error = string(errorJSON)
240+
} else {
241+
// Fallback to normalized JSON if marshaling fails
242+
record.Error = r.normalizeErrorToJSON(streamingErr.Error())
243+
}
244+
} else {
245+
// For non-StreamingError types, use the normalized approach
246+
record.Error = r.normalizeErrorToJSON(streamingErr.Error())
247+
}
248+
record.Response = "" // Ensure Response is empty for errors
249+
} else if statusCode >= 400 {
216250
record.Error = r.normalizeErrorToJSON(response)
217251
record.Response = "" // Ensure Response is empty for errors
218252
} else {
@@ -228,29 +262,51 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
228262
}
229263
}
230264

231-
func (r *OpenAIRecorder) convertStreamingResponse(streamingBody string) (string, int) {
265+
func (r *OpenAIRecorder) convertStreamingResponse(streamingBody string) (string, error) {
232266
lines := strings.Split(streamingBody, "\n")
233267
var contentBuilder strings.Builder
234268
var reasoningContentBuilder strings.Builder
235269
var lastChoice, lastChunk map[string]interface{}
236-
errorStatusCode := 0 // 0 means no error detected
237270

238271
for _, line := range lines {
239-
// Check for error lines in streaming format
272+
// Check for error lines in the streaming format
240273
if strings.HasPrefix(line, "error: ") {
241274
errorData := strings.TrimPrefix(line, "error: ")
242275
var errorObj map[string]interface{}
243276
if err := json.Unmarshal([]byte(errorData), &errorObj); err == nil {
277+
// Create a StreamingError with extracted information
278+
streamingErr := &StreamingError{
279+
StatusCode: 400, // Default status code
280+
Message: "streaming error",
281+
}
282+
244283
// Extract error code if available
245284
if code, ok := errorObj["code"].(float64); ok {
246-
errorStatusCode = int(code)
247-
} else {
248-
// Default to 400 if no specific code found
249-
errorStatusCode = 400
285+
streamingErr.StatusCode = int(code)
286+
}
287+
288+
// Extract error message if available
289+
if message, ok := errorObj["message"].(string); ok {
290+
streamingErr.Message = message
250291
}
292+
293+
// Extract error type if available
294+
if errorType, ok := errorObj["type"].(string); ok {
295+
streamingErr.Type = errorType
296+
}
297+
298+
// Store original error data as details
299+
streamingErr.Details = errorData
300+
301+
// Return the original streaming body for error cases
302+
return streamingBody, streamingErr
303+
}
304+
// If we can't parse the error JSON, create a generic error
305+
return streamingBody, &StreamingError{
306+
StatusCode: 400,
307+
Message: "unparseable streaming error",
308+
Details: errorData,
251309
}
252-
// Return the original streaming body for error cases
253-
return streamingBody, errorStatusCode
254310
}
255311

256312
if strings.HasPrefix(line, "data: ") {
@@ -283,7 +339,7 @@ func (r *OpenAIRecorder) convertStreamingResponse(streamingBody string) (string,
283339
}
284340

285341
if lastChunk == nil {
286-
return streamingBody, errorStatusCode
342+
return streamingBody, nil
287343
}
288344

289345
finalResponse := make(map[string]interface{})
@@ -315,10 +371,10 @@ func (r *OpenAIRecorder) convertStreamingResponse(streamingBody string) (string,
315371

316372
jsonResult, err := json.Marshal(finalResponse)
317373
if err != nil {
318-
return streamingBody, errorStatusCode
374+
return streamingBody, nil
319375
}
320376

321-
return string(jsonResult), errorStatusCode
377+
return string(jsonResult), nil
322378
}
323379

324380
func (r *OpenAIRecorder) GetRecordsHandler() http.HandlerFunc {

0 commit comments

Comments
 (0)