Skip to content

Commit 0c6dfd9

Browse files
Dev AgentQinYuuuu
authored andcommitted
optimize ChatCompletionRequest JSON marshal and unmarshal,support unknown data
1 parent 7029ed4 commit 0c6dfd9

File tree

2 files changed

+243
-30
lines changed

2 files changed

+243
-30
lines changed

aigateway/handler/requests.go

Lines changed: 79 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package handler
22

33
import (
4+
"encoding/json"
5+
46
"github.com/openai/openai-go/v3"
57
)
68

@@ -29,43 +31,90 @@ type ChatCompletionRequest struct {
2931
MaxTokens int `json:"max_tokens,omitempty"`
3032
Stream bool `json:"stream,omitempty"`
3133
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
34+
// RawJSON stores all unknown fields during unmarshaling
35+
RawJSON json.RawMessage `json:"-"`
3236
}
3337

34-
// ChatMessage represents a chat message
35-
type ChatMessage struct {
36-
Role string `json:"role"`
37-
Content string `json:"content"`
38-
// The tool calls generated by the model, such as function calls.
39-
ToolCalls []openai.ChatCompletionMessageToolCallUnionParam `json:"tool_calls,omitzero"`
40-
// Tool call that this message is responding to.
41-
ToolCallID string `json:"tool_call_id,omitzero"`
42-
}
38+
// UnmarshalJSON implements json.Unmarshaler interface
39+
func (r *ChatCompletionRequest) UnmarshalJSON(data []byte) error {
40+
// Create a temporary struct to hold the known fields
41+
type TempChatCompletionRequest ChatCompletionRequest
4342

44-
type StreamOptions struct {
45-
IncludeUsage bool `json:"include_usage,omitempty"`
43+
// First, unmarshal into the temporary struct
44+
var temp TempChatCompletionRequest
45+
if err := json.Unmarshal(data, &temp); err != nil {
46+
return err
47+
}
48+
49+
// Then, unmarshal into a map to get all fields
50+
var allFields map[string]json.RawMessage
51+
if err := json.Unmarshal(data, &allFields); err != nil {
52+
return err
53+
}
54+
55+
// Remove known fields from the map
56+
delete(allFields, "model")
57+
delete(allFields, "messages")
58+
delete(allFields, "tool_choice")
59+
delete(allFields, "tools")
60+
delete(allFields, "temperature")
61+
delete(allFields, "max_tokens")
62+
delete(allFields, "stream")
63+
delete(allFields, "stream_options")
64+
65+
// If there are any unknown fields left, marshal them into RawJSON
66+
var rawJSON []byte
67+
var err error
68+
if len(allFields) > 0 {
69+
rawJSON, err = json.Marshal(allFields)
70+
if err != nil {
71+
return err
72+
}
73+
}
74+
75+
// Assign the temporary struct to the original and set RawJSON
76+
*r = ChatCompletionRequest(temp)
77+
r.RawJSON = rawJSON
78+
return nil
4679
}
4780

48-
// ChatCompletionResponse represents a chat completion response
49-
type ChatCompletionResponse struct {
50-
ID string `json:"id"`
51-
Object string `json:"object"`
52-
Created int64 `json:"created"`
53-
Model string `json:"model"`
54-
Choices []struct {
55-
Index int `json:"index"`
56-
Message ChatMessage `json:"message"`
57-
FinishReason string `json:"finish_reason"`
58-
} `json:"choices"`
59-
Usage struct {
60-
PromptTokens int `json:"prompt_tokens"`
61-
CompletionTokens int `json:"completion_tokens"`
62-
TotalTokens int `json:"total_tokens"`
63-
} `json:"usage"`
81+
// MarshalJSON implements json.Marshaler interface
82+
func (r ChatCompletionRequest) MarshalJSON() ([]byte, error) {
83+
// First, marshal the known fields
84+
type TempChatCompletionRequest ChatCompletionRequest
85+
data, err := json.Marshal(TempChatCompletionRequest(r))
86+
if err != nil {
87+
return nil, err
88+
}
89+
90+
// If there are no raw JSON fields, just return the known fields
91+
if len(r.RawJSON) == 0 {
92+
return data, nil
93+
}
94+
95+
// Parse the known fields back into a map
96+
var knownFields map[string]json.RawMessage
97+
if err := json.Unmarshal(data, &knownFields); err != nil {
98+
return nil, err
99+
}
100+
101+
// Parse the raw JSON fields into a map
102+
var rawFields map[string]json.RawMessage
103+
if err := json.Unmarshal(r.RawJSON, &rawFields); err != nil {
104+
return nil, err
105+
}
106+
107+
// Merge the raw fields into the known fields
108+
for k, v := range rawFields {
109+
knownFields[k] = v
110+
}
111+
112+
// Marshal the merged map back into JSON
113+
return json.Marshal(knownFields)
64114
}
65115

66-
// ChatMessageHistoryResponse represents the chat message history response format
67-
type ChatMessageHistoryResponse struct {
68-
Messages []ChatMessage `json:"messages"`
116+
type StreamOptions struct {
117+
IncludeUsage bool `json:"include_usage,omitempty"`
69118
}
70119

71120
// EmbeddingRequest represents an embedding request structure

aigateway/handler/requests_test.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
package handler
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/openai/openai-go/v3"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestChatCompletionRequest_MarshalUnmarshal(t *testing.T) {
12+
// Test case 1: Only known fields
13+
req1 := &ChatCompletionRequest{
14+
Model: "gpt-3.5-turbo",
15+
Messages: []openai.ChatCompletionMessageParamUnion{
16+
openai.UserMessage("hello"),
17+
},
18+
Temperature: 0.7,
19+
MaxTokens: 100,
20+
}
21+
22+
// Marshal to JSON
23+
data1, err := json.Marshal(req1)
24+
assert.NoError(t, err)
25+
26+
// Unmarshal back
27+
var req1Unmarshaled ChatCompletionRequest
28+
err = json.Unmarshal(data1, &req1Unmarshaled)
29+
assert.NoError(t, err)
30+
31+
// Verify fields
32+
assert.Equal(t, req1.Model, req1Unmarshaled.Model)
33+
assert.Equal(t, len(req1.Messages), len(req1Unmarshaled.Messages))
34+
assert.Equal(t, req1.Temperature, req1Unmarshaled.Temperature)
35+
assert.Equal(t, req1.MaxTokens, req1Unmarshaled.MaxTokens)
36+
assert.Empty(t, req1Unmarshaled.RawJSON)
37+
}
38+
39+
func TestChatCompletionRequest_UnknownFields(t *testing.T) {
40+
// Test case 2: With unknown fields
41+
jsonWithUnknown := `{
42+
"model": "gpt-3.5-turbo",
43+
"messages": [
44+
{"role": "user", "content": "Hello"}
45+
],
46+
"temperature": 0.7,
47+
"max_tokens": 100,
48+
"unknown_field": "unknown_value",
49+
"another_unknown": 12345
50+
}`
51+
52+
// Unmarshal
53+
var req2 ChatCompletionRequest
54+
err := json.Unmarshal([]byte(jsonWithUnknown), &req2)
55+
assert.NoError(t, err)
56+
57+
// Verify known fields
58+
assert.Equal(t, "gpt-3.5-turbo", req2.Model)
59+
assert.Equal(t, 0.7, req2.Temperature)
60+
assert.Equal(t, 100, req2.MaxTokens)
61+
62+
// Verify unknown fields are stored in RawJSON
63+
assert.NotEmpty(t, req2.RawJSON)
64+
65+
// Marshal back and verify unknown fields are preserved
66+
data2, err := json.Marshal(req2)
67+
assert.NoError(t, err)
68+
69+
// Unmarshal into map to check all fields
70+
var resultMap map[string]interface{}
71+
err = json.Unmarshal(data2, &resultMap)
72+
assert.NoError(t, err)
73+
74+
// Check known fields
75+
assert.Equal(t, "gpt-3.5-turbo", resultMap["model"])
76+
assert.Equal(t, 0.7, resultMap["temperature"])
77+
assert.Equal(t, 100.0, resultMap["max_tokens"])
78+
79+
// Check unknown fields
80+
assert.Equal(t, "unknown_value", resultMap["unknown_field"])
81+
assert.Equal(t, 12345.0, resultMap["another_unknown"])
82+
}
83+
84+
func TestChatCompletionRequest_ComplexUnknownFields(t *testing.T) {
85+
// Test case 3: With complex unknown fields
86+
jsonWithComplexUnknown := `{
87+
"model": "gpt-3.5-turbo",
88+
"messages": [
89+
{"role": "user", "content": "Hello"}
90+
],
91+
"stream": true,
92+
"complex_field": {
93+
"nested1": "value1",
94+
"nested2": {
95+
"deep": 123
96+
}
97+
},
98+
"array_field": [1, 2, 3, 4, 5]
99+
}`
100+
101+
// Unmarshal
102+
var req3 ChatCompletionRequest
103+
err := json.Unmarshal([]byte(jsonWithComplexUnknown), &req3)
104+
assert.NoError(t, err)
105+
106+
// Verify known fields
107+
assert.Equal(t, "gpt-3.5-turbo", req3.Model)
108+
assert.True(t, req3.Stream)
109+
110+
// Verify unknown fields are stored
111+
assert.NotEmpty(t, req3.RawJSON)
112+
113+
// Marshal back and verify all fields are preserved
114+
data3, err := json.Marshal(req3)
115+
assert.NoError(t, err)
116+
117+
// Unmarshal into map to check
118+
var resultMap map[string]interface{}
119+
err = json.Unmarshal(data3, &resultMap)
120+
assert.NoError(t, err)
121+
122+
// Check known fields
123+
assert.Equal(t, "gpt-3.5-turbo", resultMap["model"])
124+
assert.True(t, resultMap["stream"].(bool))
125+
126+
// Check complex unknown fields
127+
complexField, ok := resultMap["complex_field"].(map[string]interface{})
128+
assert.True(t, ok)
129+
assert.Equal(t, "value1", complexField["nested1"])
130+
131+
nested2, ok := complexField["nested2"].(map[string]interface{})
132+
assert.True(t, ok)
133+
assert.Equal(t, 123.0, nested2["deep"])
134+
135+
// Check array unknown field
136+
arrayField, ok := resultMap["array_field"].([]interface{})
137+
assert.True(t, ok)
138+
assert.Len(t, arrayField, 5)
139+
assert.Equal(t, 1.0, arrayField[0])
140+
assert.Equal(t, 5.0, arrayField[4])
141+
}
142+
143+
func TestChatCompletionRequest_EmptyRawJSON(t *testing.T) {
144+
// Test case 4: Empty RawJSON should not cause issues
145+
req4 := &ChatCompletionRequest{
146+
Model: "gpt-3.5-turbo",
147+
Messages: []openai.ChatCompletionMessageParamUnion{
148+
openai.UserMessage("hello"),
149+
},
150+
RawJSON: nil,
151+
}
152+
153+
// Marshal should work fine
154+
data4, err := json.Marshal(req4)
155+
assert.NoError(t, err)
156+
157+
// Unmarshal should work fine
158+
var req4Unmarshaled ChatCompletionRequest
159+
err = json.Unmarshal(data4, &req4Unmarshaled)
160+
assert.NoError(t, err)
161+
162+
// RawJSON should be empty
163+
assert.Empty(t, req4Unmarshaled.RawJSON)
164+
}

0 commit comments

Comments
 (0)