diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go index 536eb2e15..336cf99ec 100644 --- a/aigateway/handler/openai.go +++ b/aigateway/handler/openai.go @@ -12,7 +12,7 @@ import ( "strings" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" + "github.com/openai/openai-go/v3" "opencsg.com/csghub-server/aigateway/component" "opencsg.com/csghub-server/aigateway/token" "opencsg.com/csghub-server/aigateway/types" @@ -166,6 +166,9 @@ func (h *OpenAIHandlerImpl) GetModel(c *gin.Context) { c.PureJSON(http.StatusOK, model) } +var _ openai.ChatCompletion +var _ openai.ChatCompletionChunk + // Chat godoc // @Security ApiKey // @Summary Chat with backend model @@ -174,7 +177,8 @@ func (h *OpenAIHandlerImpl) GetModel(c *gin.Context) { // @Accept json // @Produce json // @Param request body ChatCompletionRequest true "Chat completion request" -// @Success 200 {object} ChatCompletionResponse "OK" +// @Success 200 {object} openai.ChatCompletion "OK" +// @Success 200 {object} openai.ChatCompletionChunk "OK" // @Failure 400 {object} error "Bad request" // @Failure 404 {object} error "Model not found" // @Failure 500 {object} error "Internal server error" @@ -189,24 +193,7 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { username := httpbase.GetCurrentUser(c) userUUID := httpbase.GetCurrentUserUUID(c) chatReq := &ChatCompletionRequest{} - bodyBytes, err := io.ReadAll(c.Request.Body) - if err != nil { - slog.Error("failed to read request body", "error", err.Error()) - c.String(http.StatusBadRequest, fmt.Errorf("invalid chat compoletion request body:%w", err).Error()) - return - } - - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - c.Request.ContentLength = int64(len(bodyBytes)) - - if err = json.Unmarshal(bodyBytes, chatReq); err != nil { - slog.Error("failed to parse request body", "error", err.Error()) - c.String(http.StatusBadRequest, fmt.Errorf("invalid chat compoletion request body:%w", err).Error()) - return - } - - validate := validator.New() - if err = validate.Struct(chatReq); err != nil { + if err := c.BindJSON(chatReq); err != nil { slog.Error("invalid chat compoletion request body", "error", err.Error()) c.String(http.StatusBadRequest, fmt.Errorf("invalid chat compoletion request body:%w", err).Error()) return @@ -263,16 +250,7 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { c.String(http.StatusBadRequest, err.Error()) return } - - var reqMap map[string]interface{} - if err = json.Unmarshal(bodyBytes, &reqMap); err != nil { - slog.Error("failed to unmarshal request body to map", "error", err) - c.String(http.StatusBadRequest, fmt.Errorf("invalid chat completion request body: %w", err).Error()) - return - } - // directly update model field in request map - reqMap["model"] = modelName - + chatReq.Model = modelName if chatReq.Stream { c.Writer.Header().Set("Content-Type", "text/event-stream") if !strings.Contains(model.ImageID, "vllm-cpu") { @@ -283,13 +261,7 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { } // marshal updated request map back to JSON bytes - updatedBodyBytes, err := json.Marshal(reqMap) - if err != nil { - slog.Error("failed to marshal updated request map", "error", err) - c.String(http.StatusInternalServerError, fmt.Errorf("failed to process chat request: %w", err).Error()) - return - } - + updatedBodyBytes, _ := json.Marshal(chatReq) c.Request.Body = io.NopCloser(bytes.NewReader(updatedBodyBytes)) c.Request.ContentLength = int64(len(updatedBodyBytes)) rp, _ := proxy.NewReverseProxy(target) diff --git a/aigateway/handler/requests.go b/aigateway/handler/requests.go index 0b376cc14..75fc8aef2 100644 --- a/aigateway/handler/requests.go +++ b/aigateway/handler/requests.go @@ -1,6 +1,8 @@ package handler import ( + "encoding/json" + "github.com/openai/openai-go/v3" ) @@ -29,43 +31,90 @@ type ChatCompletionRequest struct { MaxTokens int `json:"max_tokens,omitempty"` Stream bool `json:"stream,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"` + // RawJSON stores all unknown fields during unmarshaling + RawJSON json.RawMessage `json:"-"` } -// ChatMessage represents a chat message -type ChatMessage struct { - Role string `json:"role"` - Content string `json:"content"` - // The tool calls generated by the model, such as function calls. - ToolCalls []openai.ChatCompletionMessageToolCallUnionParam `json:"tool_calls,omitzero"` - // Tool call that this message is responding to. - ToolCallID string `json:"tool_call_id,omitzero"` -} +// UnmarshalJSON implements json.Unmarshaler interface +func (r *ChatCompletionRequest) UnmarshalJSON(data []byte) error { + // Create a temporary struct to hold the known fields + type TempChatCompletionRequest ChatCompletionRequest -type StreamOptions struct { - IncludeUsage bool `json:"include_usage,omitempty"` + // First, unmarshal into the temporary struct + var temp TempChatCompletionRequest + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + // Then, unmarshal into a map to get all fields + var allFields map[string]json.RawMessage + if err := json.Unmarshal(data, &allFields); err != nil { + return err + } + + // Remove known fields from the map + delete(allFields, "model") + delete(allFields, "messages") + delete(allFields, "tool_choice") + delete(allFields, "tools") + delete(allFields, "temperature") + delete(allFields, "max_tokens") + delete(allFields, "stream") + delete(allFields, "stream_options") + + // If there are any unknown fields left, marshal them into RawJSON + var rawJSON []byte + var err error + if len(allFields) > 0 { + rawJSON, err = json.Marshal(allFields) + if err != nil { + return err + } + } + + // Assign the temporary struct to the original and set RawJSON + *r = ChatCompletionRequest(temp) + r.RawJSON = rawJSON + return nil } -// ChatCompletionResponse represents a chat completion response -type ChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Message ChatMessage `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` +// MarshalJSON implements json.Marshaler interface +func (r ChatCompletionRequest) MarshalJSON() ([]byte, error) { + // First, marshal the known fields + type TempChatCompletionRequest ChatCompletionRequest + data, err := json.Marshal(TempChatCompletionRequest(r)) + if err != nil { + return nil, err + } + + // If there are no raw JSON fields, just return the known fields + if len(r.RawJSON) == 0 { + return data, nil + } + + // Parse the known fields back into a map + var knownFields map[string]json.RawMessage + if err := json.Unmarshal(data, &knownFields); err != nil { + return nil, err + } + + // Parse the raw JSON fields into a map + var rawFields map[string]json.RawMessage + if err := json.Unmarshal(r.RawJSON, &rawFields); err != nil { + return nil, err + } + + // Merge the raw fields into the known fields + for k, v := range rawFields { + knownFields[k] = v + } + + // Marshal the merged map back into JSON + return json.Marshal(knownFields) } -// ChatMessageHistoryResponse represents the chat message history response format -type ChatMessageHistoryResponse struct { - Messages []ChatMessage `json:"messages"` +type StreamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` } // EmbeddingRequest represents an embedding request structure diff --git a/aigateway/handler/requests_test.go b/aigateway/handler/requests_test.go new file mode 100644 index 000000000..8b6d971cb --- /dev/null +++ b/aigateway/handler/requests_test.go @@ -0,0 +1,164 @@ +package handler + +import ( + "encoding/json" + "testing" + + "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/assert" +) + +func TestChatCompletionRequest_MarshalUnmarshal(t *testing.T) { + // Test case 1: Only known fields + req1 := &ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + }, + Temperature: 0.7, + MaxTokens: 100, + } + + // Marshal to JSON + data1, err := json.Marshal(req1) + assert.NoError(t, err) + + // Unmarshal back + var req1Unmarshaled ChatCompletionRequest + err = json.Unmarshal(data1, &req1Unmarshaled) + assert.NoError(t, err) + + // Verify fields + assert.Equal(t, req1.Model, req1Unmarshaled.Model) + assert.Equal(t, len(req1.Messages), len(req1Unmarshaled.Messages)) + assert.Equal(t, req1.Temperature, req1Unmarshaled.Temperature) + assert.Equal(t, req1.MaxTokens, req1Unmarshaled.MaxTokens) + assert.Empty(t, req1Unmarshaled.RawJSON) +} + +func TestChatCompletionRequest_UnknownFields(t *testing.T) { + // Test case 2: With unknown fields + jsonWithUnknown := `{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "temperature": 0.7, + "max_tokens": 100, + "unknown_field": "unknown_value", + "another_unknown": 12345 + }` + + // Unmarshal + var req2 ChatCompletionRequest + err := json.Unmarshal([]byte(jsonWithUnknown), &req2) + assert.NoError(t, err) + + // Verify known fields + assert.Equal(t, "gpt-3.5-turbo", req2.Model) + assert.Equal(t, 0.7, req2.Temperature) + assert.Equal(t, 100, req2.MaxTokens) + + // Verify unknown fields are stored in RawJSON + assert.NotEmpty(t, req2.RawJSON) + + // Marshal back and verify unknown fields are preserved + data2, err := json.Marshal(req2) + assert.NoError(t, err) + + // Unmarshal into map to check all fields + var resultMap map[string]interface{} + err = json.Unmarshal(data2, &resultMap) + assert.NoError(t, err) + + // Check known fields + assert.Equal(t, "gpt-3.5-turbo", resultMap["model"]) + assert.Equal(t, 0.7, resultMap["temperature"]) + assert.Equal(t, 100.0, resultMap["max_tokens"]) + + // Check unknown fields + assert.Equal(t, "unknown_value", resultMap["unknown_field"]) + assert.Equal(t, 12345.0, resultMap["another_unknown"]) +} + +func TestChatCompletionRequest_ComplexUnknownFields(t *testing.T) { + // Test case 3: With complex unknown fields + jsonWithComplexUnknown := `{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "stream": true, + "complex_field": { + "nested1": "value1", + "nested2": { + "deep": 123 + } + }, + "array_field": [1, 2, 3, 4, 5] + }` + + // Unmarshal + var req3 ChatCompletionRequest + err := json.Unmarshal([]byte(jsonWithComplexUnknown), &req3) + assert.NoError(t, err) + + // Verify known fields + assert.Equal(t, "gpt-3.5-turbo", req3.Model) + assert.True(t, req3.Stream) + + // Verify unknown fields are stored + assert.NotEmpty(t, req3.RawJSON) + + // Marshal back and verify all fields are preserved + data3, err := json.Marshal(req3) + assert.NoError(t, err) + + // Unmarshal into map to check + var resultMap map[string]interface{} + err = json.Unmarshal(data3, &resultMap) + assert.NoError(t, err) + + // Check known fields + assert.Equal(t, "gpt-3.5-turbo", resultMap["model"]) + assert.True(t, resultMap["stream"].(bool)) + + // Check complex unknown fields + complexField, ok := resultMap["complex_field"].(map[string]interface{}) + assert.True(t, ok) + assert.Equal(t, "value1", complexField["nested1"]) + + nested2, ok := complexField["nested2"].(map[string]interface{}) + assert.True(t, ok) + assert.Equal(t, 123.0, nested2["deep"]) + + // Check array unknown field + arrayField, ok := resultMap["array_field"].([]interface{}) + assert.True(t, ok) + assert.Len(t, arrayField, 5) + assert.Equal(t, 1.0, arrayField[0]) + assert.Equal(t, 5.0, arrayField[4]) +} + +func TestChatCompletionRequest_EmptyRawJSON(t *testing.T) { + // Test case 4: Empty RawJSON should not cause issues + req4 := &ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + }, + RawJSON: nil, + } + + // Marshal should work fine + data4, err := json.Marshal(req4) + assert.NoError(t, err) + + // Unmarshal should work fine + var req4Unmarshaled ChatCompletionRequest + err = json.Unmarshal(data4, &req4Unmarshaled) + assert.NoError(t, err) + + // RawJSON should be empty + assert.Empty(t, req4Unmarshaled.RawJSON) +}