46 changes: 9 additions & 37 deletions aigateway/handler/openai.go
@@ -12,7 +12,7 @@ import (
"strings"

"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"github.com/openai/openai-go/v3"
"opencsg.com/csghub-server/aigateway/component"
"opencsg.com/csghub-server/aigateway/token"
"opencsg.com/csghub-server/aigateway/types"
@@ -166,6 +166,9 @@ func (h *OpenAIHandlerImpl) GetModel(c *gin.Context) {
c.PureJSON(http.StatusOK, model)
}

// Blank references keep the openai types used by the swagger annotations below imported.
var _ openai.ChatCompletion
var _ openai.ChatCompletionChunk

// Chat godoc
// @Security ApiKey
// @Summary Chat with backend model
@@ -174,7 +177,8 @@ func (h *OpenAIHandlerImpl) GetModel(c *gin.Context) {
// @Accept json
// @Produce json
// @Param request body ChatCompletionRequest true "Chat completion request"
// @Success 200 {object} ChatCompletionResponse "OK"
// @Success 200 {object} openai.ChatCompletion "OK"
// @Success 200 {object} openai.ChatCompletionChunk "OK"
// @Failure 400 {object} error "Bad request"
// @Failure 404 {object} error "Model not found"
// @Failure 500 {object} error "Internal server error"
@@ -189,24 +193,7 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) {
username := httpbase.GetCurrentUser(c)
userUUID := httpbase.GetCurrentUserUUID(c)
chatReq := &ChatCompletionRequest{}
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
slog.Error("failed to read request body", "error", err.Error())
c.String(http.StatusBadRequest, fmt.Errorf("invalid chat compoletion request body:%w", err).Error())
return
}

c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
c.Request.ContentLength = int64(len(bodyBytes))

if err = json.Unmarshal(bodyBytes, chatReq); err != nil {
slog.Error("failed to parse request body", "error", err.Error())
c.String(http.StatusBadRequest, fmt.Errorf("invalid chat compoletion request body:%w", err).Error())
return
}

validate := validator.New()
if err = validate.Struct(chatReq); err != nil {
if err := c.BindJSON(chatReq); err != nil {
slog.Error("invalid chat compoletion request body", "error", err.Error())
c.String(http.StatusBadRequest, fmt.Errorf("invalid chat compoletion request body:%w", err).Error())
return
@@ -263,16 +250,7 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) {
c.String(http.StatusBadRequest, err.Error())
return
}

var reqMap map[string]interface{}
if err = json.Unmarshal(bodyBytes, &reqMap); err != nil {
slog.Error("failed to unmarshal request body to map", "error", err)
c.String(http.StatusBadRequest, fmt.Errorf("invalid chat completion request body: %w", err).Error())
return
}
// directly update model field in request map
reqMap["model"] = modelName

chatReq.Model = modelName
if chatReq.Stream {
c.Writer.Header().Set("Content-Type", "text/event-stream")
if !strings.Contains(model.ImageID, "vllm-cpu") {
@@ -283,13 +261,7 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) {
}

// marshal the updated request back to JSON bytes (RawJSON fields are merged back in by MarshalJSON)
updatedBodyBytes, err := json.Marshal(reqMap)
if err != nil {
slog.Error("failed to marshal updated request map", "error", err)
c.String(http.StatusInternalServerError, fmt.Errorf("failed to process chat request: %w", err).Error())
return
}

updatedBodyBytes, _ := json.Marshal(chatReq)
c.Request.Body = io.NopCloser(bytes.NewReader(updatedBodyBytes))
c.Request.ContentLength = int64(len(updatedBodyBytes))
rp, _ := proxy.NewReverseProxy(target)
109 changes: 79 additions & 30 deletions aigateway/handler/requests.go
@@ -1,6 +1,8 @@
package handler

import (
"encoding/json"

"github.com/openai/openai-go/v3"
)

@@ -29,43 +31,90 @@ type ChatCompletionRequest struct {
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
// RawJSON stores all unknown fields during unmarshaling
RawJSON json.RawMessage `json:"-"`
}

// ChatMessage represents a chat message
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
// The tool calls generated by the model, such as function calls.
ToolCalls []openai.ChatCompletionMessageToolCallUnionParam `json:"tool_calls,omitzero"`
// Tool call that this message is responding to.
ToolCallID string `json:"tool_call_id,omitzero"`
}
// UnmarshalJSON implements json.Unmarshaler interface
func (r *ChatCompletionRequest) UnmarshalJSON(data []byte) error {
// Create a temporary struct to hold the known fields
type TempChatCompletionRequest ChatCompletionRequest

type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
// First, unmarshal into the temporary struct
var temp TempChatCompletionRequest
if err := json.Unmarshal(data, &temp); err != nil {
return err
}

// Then, unmarshal into a map to get all fields
var allFields map[string]json.RawMessage
if err := json.Unmarshal(data, &allFields); err != nil {
return err
}

// Remove known fields from the map
delete(allFields, "model")
delete(allFields, "messages")
delete(allFields, "tool_choice")
delete(allFields, "tools")
delete(allFields, "temperature")
delete(allFields, "max_tokens")
delete(allFields, "stream")
delete(allFields, "stream_options")

// If there are any unknown fields left, marshal them into RawJSON
var rawJSON []byte
var err error
if len(allFields) > 0 {
rawJSON, err = json.Marshal(allFields)
if err != nil {
return err
}
}

// Assign the temporary struct to the original and set RawJSON
*r = ChatCompletionRequest(temp)
r.RawJSON = rawJSON
return nil
}

// ChatCompletionResponse represents a chat completion response
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message ChatMessage `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
// MarshalJSON implements json.Marshaler interface
func (r ChatCompletionRequest) MarshalJSON() ([]byte, error) {
// First, marshal the known fields
type TempChatCompletionRequest ChatCompletionRequest
data, err := json.Marshal(TempChatCompletionRequest(r))
if err != nil {
return nil, err
}

// If there are no raw JSON fields, just return the known fields
if len(r.RawJSON) == 0 {
return data, nil
}

// Parse the known fields back into a map
var knownFields map[string]json.RawMessage
if err := json.Unmarshal(data, &knownFields); err != nil {
return nil, err
}

// Parse the raw JSON fields into a map
var rawFields map[string]json.RawMessage
if err := json.Unmarshal(r.RawJSON, &rawFields); err != nil {
return nil, err
}

// Merge the raw fields into the known fields
for k, v := range rawFields {
knownFields[k] = v
}

// Marshal the merged map back into JSON
return json.Marshal(knownFields)
}

// ChatMessageHistoryResponse represents the chat message history response format
type ChatMessageHistoryResponse struct {
Messages []ChatMessage `json:"messages"`
type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}

// EmbeddingRequest represents an embedding request structure
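
A minimal sketch of the round trip the two files above implement: bind the client body into ChatCompletionRequest, rewrite only the model, and re-marshal so fields the gateway does not model explicitly (top_p here, chosen purely for illustration) still reach the backend. The handler import path is assumed from the repo layout, and the model names are made up:

package main

import (
	"encoding/json"
	"fmt"

	"opencsg.com/csghub-server/aigateway/handler"
)

func main() {
	// Client body with a field the gateway has no struct field for (top_p).
	body := []byte(`{"model":"client-alias","messages":[{"role":"user","content":"hi"}],"top_p":0.9}`)

	var req handler.ChatCompletionRequest
	if err := json.Unmarshal(body, &req); err != nil {
		panic(err)
	}

	// The handler only rewrites the model; UnmarshalJSON has parked top_p in req.RawJSON.
	req.Model = "deployed-model-name"

	// MarshalJSON merges RawJSON back in, so top_p survives the rewrite
	// (this is what requests_test.go below asserts for unknown fields).
	out, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}
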
164 changes: 164 additions & 0 deletions aigateway/handler/requests_test.go
@@ -0,0 +1,164 @@
package handler

import (
"encoding/json"
"testing"

"github.com/openai/openai-go/v3"
"github.com/stretchr/testify/assert"
)

func TestChatCompletionRequest_MarshalUnmarshal(t *testing.T) {
// Test case 1: Only known fields
req1 := &ChatCompletionRequest{
Model: "gpt-3.5-turbo",
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("hello"),
},
Temperature: 0.7,
MaxTokens: 100,
}

// Marshal to JSON
data1, err := json.Marshal(req1)
assert.NoError(t, err)

// Unmarshal back
var req1Unmarshaled ChatCompletionRequest
err = json.Unmarshal(data1, &req1Unmarshaled)
assert.NoError(t, err)

// Verify fields
assert.Equal(t, req1.Model, req1Unmarshaled.Model)
assert.Equal(t, len(req1.Messages), len(req1Unmarshaled.Messages))
assert.Equal(t, req1.Temperature, req1Unmarshaled.Temperature)
assert.Equal(t, req1.MaxTokens, req1Unmarshaled.MaxTokens)
assert.Empty(t, req1Unmarshaled.RawJSON)
}

func TestChatCompletionRequest_UnknownFields(t *testing.T) {
// Test case 2: With unknown fields
jsonWithUnknown := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
],
"temperature": 0.7,
"max_tokens": 100,
"unknown_field": "unknown_value",
"another_unknown": 12345
}`

// Unmarshal
var req2 ChatCompletionRequest
err := json.Unmarshal([]byte(jsonWithUnknown), &req2)
assert.NoError(t, err)

// Verify known fields
assert.Equal(t, "gpt-3.5-turbo", req2.Model)
assert.Equal(t, 0.7, req2.Temperature)
assert.Equal(t, 100, req2.MaxTokens)

// Verify unknown fields are stored in RawJSON
assert.NotEmpty(t, req2.RawJSON)

// Marshal back and verify unknown fields are preserved
data2, err := json.Marshal(req2)
assert.NoError(t, err)

// Unmarshal into map to check all fields
var resultMap map[string]interface{}
err = json.Unmarshal(data2, &resultMap)
assert.NoError(t, err)

// Check known fields
assert.Equal(t, "gpt-3.5-turbo", resultMap["model"])
assert.Equal(t, 0.7, resultMap["temperature"])
assert.Equal(t, 100.0, resultMap["max_tokens"])

// Check unknown fields
assert.Equal(t, "unknown_value", resultMap["unknown_field"])
assert.Equal(t, 12345.0, resultMap["another_unknown"])
}

func TestChatCompletionRequest_ComplexUnknownFields(t *testing.T) {
// Test case 3: With complex unknown fields
jsonWithComplexUnknown := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": true,
"complex_field": {
"nested1": "value1",
"nested2": {
"deep": 123
}
},
"array_field": [1, 2, 3, 4, 5]
}`

// Unmarshal
var req3 ChatCompletionRequest
err := json.Unmarshal([]byte(jsonWithComplexUnknown), &req3)
assert.NoError(t, err)

// Verify known fields
assert.Equal(t, "gpt-3.5-turbo", req3.Model)
assert.True(t, req3.Stream)

// Verify unknown fields are stored
assert.NotEmpty(t, req3.RawJSON)

// Marshal back and verify all fields are preserved
data3, err := json.Marshal(req3)
assert.NoError(t, err)

// Unmarshal into map to check
var resultMap map[string]interface{}
err = json.Unmarshal(data3, &resultMap)
assert.NoError(t, err)

// Check known fields
assert.Equal(t, "gpt-3.5-turbo", resultMap["model"])
assert.True(t, resultMap["stream"].(bool))

// Check complex unknown fields
complexField, ok := resultMap["complex_field"].(map[string]interface{})
assert.True(t, ok)
assert.Equal(t, "value1", complexField["nested1"])

nested2, ok := complexField["nested2"].(map[string]interface{})
assert.True(t, ok)
assert.Equal(t, 123.0, nested2["deep"])

// Check array unknown field
arrayField, ok := resultMap["array_field"].([]interface{})
assert.True(t, ok)
assert.Len(t, arrayField, 5)
assert.Equal(t, 1.0, arrayField[0])
assert.Equal(t, 5.0, arrayField[4])
}

func TestChatCompletionRequest_EmptyRawJSON(t *testing.T) {
// Test case 4: Empty RawJSON should not cause issues
req4 := &ChatCompletionRequest{
Model: "gpt-3.5-turbo",
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("hello"),
},
RawJSON: nil,
}

// Marshal should work fine
data4, err := json.Marshal(req4)
assert.NoError(t, err)

// Unmarshal should work fine
var req4Unmarshaled ChatCompletionRequest
err = json.Unmarshal(data4, &req4Unmarshaled)
assert.NoError(t, err)

// RawJSON should be empty
assert.Empty(t, req4Unmarshaled.RawJSON)
}