Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions aigateway/handler/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,15 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Input == "" || req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Model and input cannot be empty"})
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Model cannot be empty"})
return
}
if req.Input.OfString.String() == "" &&
len(req.Input.OfArrayOfStrings) == 0 &&
len(req.Input.OfArrayOfTokenArrays) == 0 &&
len(req.Input.OfArrayOfTokens) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "Input cannot be empty"})
return
}
modelID := req.Model
Expand Down Expand Up @@ -415,7 +422,9 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) {
ImageID: model.ImageID,
})
w := NewResponseWriterWrapperEmbedding(c.Writer, tokenCounter)
tokenCounter.Input(req.Input)
if req.Input.OfString.String() != "" {
tokenCounter.Input(req.Input.OfString.Value)
}

rp.ServeHTTP(w, c.Request, "", host)
go func() {
Expand Down
48 changes: 36 additions & 12 deletions aigateway/handler/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
tester, c, w := setupTest(t)
// Empty Input
embeddingReq := EmbeddingRequest{
Model: "model1:svc1",
Input: "",
EmbeddingNewParams: openai.EmbeddingNewParams{
Model: "model1:svc1",
Input: openai.EmbeddingNewParamsInputUnion{
OfArrayOfStrings: []string{},
},
},
}
body, _ := json.Marshal(embeddingReq)
c.Request.Method = http.MethodPost
Expand All @@ -424,8 +428,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
httpbase.SetCurrentUser(c, "testuser")
httpbase.SetCurrentUserUUID(c, "testuuid")
embeddingReq = EmbeddingRequest{
Model: "",
Input: "test input",
EmbeddingNewParams: openai.EmbeddingNewParams{
Model: "",
Input: openai.EmbeddingNewParamsInputUnion{
OfArrayOfStrings: []string{"test input"},
},
},
}
body, _ = json.Marshal(embeddingReq)
c.Request.Body = io.NopCloser(bytes.NewReader(body))
Expand All @@ -438,8 +446,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
t.Run("model not found", func(t *testing.T) {
tester, c, w := setupTest(t)
embeddingReq := EmbeddingRequest{
Model: "nonexistent:svc",
Input: "test input",
EmbeddingNewParams: openai.EmbeddingNewParams{
Model: "nonexistent:svc",
Input: openai.EmbeddingNewParamsInputUnion{
OfArrayOfStrings: []string{"test input"},
},
},
}
body, _ := json.Marshal(embeddingReq)
c.Request.Method = http.MethodPost
Expand All @@ -455,8 +467,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
t.Run("get model error", func(t *testing.T) {
tester, c, w := setupTest(t)
embeddingReq := EmbeddingRequest{
Model: "model1:svc1",
Input: "test input",
EmbeddingNewParams: openai.EmbeddingNewParams{
Model: "model1:svc1",
Input: openai.EmbeddingNewParamsInputUnion{
OfArrayOfStrings: []string{"test input"},
},
},
}
body, _ := json.Marshal(embeddingReq)
c.Request.Method = http.MethodPost
Expand All @@ -472,8 +488,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
t.Run("model not running", func(t *testing.T) {
tester, c, w := setupTest(t)
embeddingReq := EmbeddingRequest{
Model: "model1:svc1",
Input: "test input",
EmbeddingNewParams: openai.EmbeddingNewParams{
Model: "model1:svc1",
Input: openai.EmbeddingNewParamsInputUnion{
OfArrayOfStrings: []string{"test input"},
},
},
}
body, _ := json.Marshal(embeddingReq)
c.Request.Method = http.MethodPost
Expand Down Expand Up @@ -503,8 +523,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
t.Run("model without svc name", func(t *testing.T) {
tester, c, _ := setupTest(t)
embeddingReq := EmbeddingRequest{
Model: "model1",
Input: "test input",
EmbeddingNewParams: openai.EmbeddingNewParams{
Model: "model1",
Input: openai.EmbeddingNewParamsInputUnion{
OfArrayOfStrings: []string{"test input"},
},
},
}
body, _ := json.Marshal(embeddingReq)
c.Request.Method = http.MethodPost
Expand Down
77 changes: 74 additions & 3 deletions aigateway/handler/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,78 @@ type StreamOptions struct {

// EmbeddingRequest represents an embedding request structure
type EmbeddingRequest struct {
Input string `json:"input"` // Input text content
Model string `json:"model"` // Model name used (e.g., "text-embedding-ada-002")
EncodingFormat string `json:"encoding_format,omitempty"` // Encoding format (e.g., "float")
openai.EmbeddingNewParams
// RawJSON stores all unknown fields during unmarshaling
RawJSON json.RawMessage `json:"-"`
}

func (r *EmbeddingRequest) UnmarshalJSON(data []byte) error {
// Create a temporary struct to hold the known fields
type TempEmbeddingRequest EmbeddingRequest

// First, unmarshal into the temporary struct
var temp TempEmbeddingRequest
if err := json.Unmarshal(data, &temp); err != nil {
return err
}

// Then, unmarshal into a map to get all fields
var allFields map[string]json.RawMessage
if err := json.Unmarshal(data, &allFields); err != nil {
return err
}

// Remove known fields from the map
delete(allFields, "model")
delete(allFields, "input")
delete(allFields, "encoding_format")

// If there are any unknown fields left, marshal them into RawJSON
var rawJSON []byte
var err error
if len(allFields) > 0 {
rawJSON, err = json.Marshal(allFields)
if err != nil {
return err
}
}

// Assign the temporary struct to the original and set RawJSON
*r = EmbeddingRequest(temp)
r.RawJSON = rawJSON
return nil
}

func (r EmbeddingRequest) MarshalJSON() ([]byte, error) {
// First, marshal the known fields
type TempEmbeddingRequest EmbeddingRequest
data, err := json.Marshal(TempEmbeddingRequest(r))
if err != nil {
return nil, err
}

// If there are no raw JSON fields, just return the known fields
if len(r.RawJSON) == 0 {
return data, nil
}

// Parse the known fields back into a map
var knownFields map[string]json.RawMessage
if err := json.Unmarshal(data, &knownFields); err != nil {
return nil, err
}

// Parse the raw JSON fields into a map
var rawFields map[string]json.RawMessage
if err := json.Unmarshal(r.RawJSON, &rawFields); err != nil {
return nil, err
}

// Merge the raw fields into the known fields
for k, v := range rawFields {
knownFields[k] = v
}

// Marshal the merged map back into JSON
return json.Marshal(knownFields)
}
67 changes: 67 additions & 0 deletions aigateway/handler/requests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,70 @@ func TestChatCompletionRequest_EmptyRawJSON(t *testing.T) {
// RawJSON should be empty
assert.Empty(t, req4Unmarshaled.RawJSON)
}

func TestEmbeddingRequest_MarshalUnmarshal(t *testing.T) {
// Test case 1: Only known fields
req1 := &EmbeddingRequest{
EmbeddingNewParams: openai.EmbeddingNewParams{
Input: openai.EmbeddingNewParamsInputUnion{
OfArrayOfStrings: []string{"Hello, world!"},
},
Model: "text-embedding-ada-002",
},
}

// Marshal to JSON
data1, err := json.Marshal(req1)
assert.NoError(t, err)

// Unmarshal back
var req1Unmarshaled EmbeddingRequest
err = json.Unmarshal(data1, &req1Unmarshaled)
assert.NoError(t, err)

// Verify fields
assert.Equal(t, req1.Model, req1Unmarshaled.Model)
assert.Equal(t, len(req1.Input.OfArrayOfStrings), len(req1Unmarshaled.Input.OfArrayOfStrings))
assert.Empty(t, req1Unmarshaled.RawJSON)
}

func TestEmbeddingRequest_UnknownFields(t *testing.T) {
// Test case 2: With unknown fields
jsonWithUnknown := `{
"model": "text-embedding-ada-002",
"input": ["Hello, world!"],
"unknown_field": "unknown_value",
"another_unknown": 12345
}`

// Unmarshal
var req2 EmbeddingRequest
err := json.Unmarshal([]byte(jsonWithUnknown), &req2)
assert.NoError(t, err)

// Verify known fields
assert.Equal(t, "text-embedding-ada-002", req2.Model)
assert.Equal(t, 1, len(req2.Input.OfArrayOfStrings))

// Verify unknown fields are stored in RawJSON
assert.NotEmpty(t, req2.RawJSON)

// Marshal back and verify unknown fields are preserved
data2, err := json.Marshal(req2)
assert.NoError(t, err)

// Unmarshal into map to check all fields
var resultMap map[string]interface{}
err = json.Unmarshal(data2, &resultMap)
assert.NoError(t, err)

// Check known fields
assert.Equal(t, "text-embedding-ada-002", resultMap["model"])
inputArray, ok := resultMap["input"].([]interface{})
assert.True(t, ok)
assert.Equal(t, "Hello, world!", inputArray[0])

// Check unknown fields
assert.Equal(t, "unknown_value", resultMap["unknown_field"])
assert.Equal(t, 12345.0, resultMap["another_unknown"])
}
Loading