diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go index ef9d2b5e3..f94b0c769 100644 --- a/aigateway/handler/openai.go +++ b/aigateway/handler/openai.go @@ -346,8 +346,15 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - if req.Input == "" || req.Model == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "Model and input cannot be empty"}) + if req.Model == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Model cannot be empty"}) + return + } + if req.Input.OfString.String() == "" && + len(req.Input.OfArrayOfStrings) == 0 && + len(req.Input.OfArrayOfTokenArrays) == 0 && + len(req.Input.OfArrayOfTokens) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Input cannot be empty"}) return } modelID := req.Model @@ -415,7 +422,9 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) { ImageID: model.ImageID, }) w := NewResponseWriterWrapperEmbedding(c.Writer, tokenCounter) - tokenCounter.Input(req.Input) + if req.Input.OfString.String() != "" { + tokenCounter.Input(req.Input.OfString.Value) + } rp.ServeHTTP(w, c.Request, "", host) go func() { diff --git a/aigateway/handler/openai_test.go b/aigateway/handler/openai_test.go index a26611e16..ffdef7986 100644 --- a/aigateway/handler/openai_test.go +++ b/aigateway/handler/openai_test.go @@ -403,8 +403,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { tester, c, w := setupTest(t) // Empty Input embeddingReq := EmbeddingRequest{ - Model: "model1:svc1", - Input: "", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "model1:svc1", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{}, + }, + }, } body, _ := json.Marshal(embeddingReq) c.Request.Method = http.MethodPost @@ -424,8 +428,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { httpbase.SetCurrentUser(c, "testuser") httpbase.SetCurrentUserUUID(c, "testuuid") embeddingReq = EmbeddingRequest{ - Model: "", - Input: "test input", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"test input"}, + }, + }, } body, _ = json.Marshal(embeddingReq) c.Request.Body = io.NopCloser(bytes.NewReader(body)) @@ -438,8 +446,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { t.Run("model not found", func(t *testing.T) { tester, c, w := setupTest(t) embeddingReq := EmbeddingRequest{ - Model: "nonexistent:svc", - Input: "test input", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "nonexistent:svc", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"test input"}, + }, + }, } body, _ := json.Marshal(embeddingReq) c.Request.Method = http.MethodPost @@ -455,8 +467,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { t.Run("get model error", func(t *testing.T) { tester, c, w := setupTest(t) embeddingReq := EmbeddingRequest{ - Model: "model1:svc1", - Input: "test input", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "model1:svc1", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"test input"}, + }, + }, } body, _ := json.Marshal(embeddingReq) c.Request.Method = http.MethodPost @@ -472,8 +488,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { t.Run("model not running", func(t *testing.T) { tester, c, w := setupTest(t) embeddingReq := EmbeddingRequest{ - Model: "model1:svc1", - Input: "test input", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "model1:svc1", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"test input"}, + }, + }, } body, _ := json.Marshal(embeddingReq) c.Request.Method = http.MethodPost @@ -503,8 +523,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { t.Run("model without svc name", func(t *testing.T) { tester, c, _ := setupTest(t) embeddingReq := EmbeddingRequest{ - Model: "model1", - Input: "test input", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "model1", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"test input"}, + }, + }, } body, _ := json.Marshal(embeddingReq) c.Request.Method = http.MethodPost diff --git a/aigateway/handler/requests.go b/aigateway/handler/requests.go index 75fc8aef2..dedc072c9 100644 --- a/aigateway/handler/requests.go +++ b/aigateway/handler/requests.go @@ -119,7 +119,78 @@ type StreamOptions struct { // EmbeddingRequest represents an embedding request structure type EmbeddingRequest struct { - Input string `json:"input"` // Input text content - Model string `json:"model"` // Model name used (e.g., "text-embedding-ada-002") - EncodingFormat string `json:"encoding_format,omitempty"` // Encoding format (e.g., "float") + openai.EmbeddingNewParams + // RawJSON stores all unknown fields during unmarshaling + RawJSON json.RawMessage `json:"-"` +} + +func (r *EmbeddingRequest) UnmarshalJSON(data []byte) error { + // Create a temporary struct to hold the known fields + type TempEmbeddingRequest EmbeddingRequest + + // First, unmarshal into the temporary struct + var temp TempEmbeddingRequest + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + // Then, unmarshal into a map to get all fields + var allFields map[string]json.RawMessage + if err := json.Unmarshal(data, &allFields); err != nil { + return err + } + + // Remove known fields from the map + delete(allFields, "model") + delete(allFields, "input") + delete(allFields, "encoding_format") + + // If there are any unknown fields left, marshal them into RawJSON + var rawJSON []byte + var err error + if len(allFields) > 0 { + rawJSON, err = json.Marshal(allFields) + if err != nil { + return err + } + } + + // Assign the temporary struct to the original and set RawJSON + *r = EmbeddingRequest(temp) + r.RawJSON = rawJSON + return nil +} + +func (r EmbeddingRequest) MarshalJSON() ([]byte, error) { + // First, marshal the known fields + type TempEmbeddingRequest EmbeddingRequest + data, err := json.Marshal(TempEmbeddingRequest(r)) + if err != nil { + return nil, err + } + + // If there are no raw JSON fields, just return the known fields + if len(r.RawJSON) == 0 { + return data, nil + } + + // Parse the known fields back into a map + var knownFields map[string]json.RawMessage + if err := json.Unmarshal(data, &knownFields); err != nil { + return nil, err + } + + // Parse the raw JSON fields into a map + var rawFields map[string]json.RawMessage + if err := json.Unmarshal(r.RawJSON, &rawFields); err != nil { + return nil, err + } + + // Merge the raw fields into the known fields + for k, v := range rawFields { + knownFields[k] = v + } + + // Marshal the merged map back into JSON + return json.Marshal(knownFields) } diff --git a/aigateway/handler/requests_test.go b/aigateway/handler/requests_test.go index 8b6d971cb..7ec2f159c 100644 --- a/aigateway/handler/requests_test.go +++ b/aigateway/handler/requests_test.go @@ -162,3 +162,70 @@ func TestChatCompletionRequest_EmptyRawJSON(t *testing.T) { // RawJSON should be empty assert.Empty(t, req4Unmarshaled.RawJSON) } + +func TestEmbeddingRequest_MarshalUnmarshal(t *testing.T) { + // Test case 1: Only known fields + req1 := &EmbeddingRequest{ + EmbeddingNewParams: openai.EmbeddingNewParams{ + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"Hello, world!"}, + }, + Model: "text-embedding-ada-002", + }, + } + + // Marshal to JSON + data1, err := json.Marshal(req1) + assert.NoError(t, err) + + // Unmarshal back + var req1Unmarshaled EmbeddingRequest + err = json.Unmarshal(data1, &req1Unmarshaled) + assert.NoError(t, err) + + // Verify fields + assert.Equal(t, req1.Model, req1Unmarshaled.Model) + assert.Equal(t, len(req1.Input.OfArrayOfStrings), len(req1Unmarshaled.Input.OfArrayOfStrings)) + assert.Empty(t, req1Unmarshaled.RawJSON) +} + +func TestEmbeddingRequest_UnknownFields(t *testing.T) { + // Test case 2: With unknown fields + jsonWithUnknown := `{ + "model": "text-embedding-ada-002", + "input": ["Hello, world!"], + "unknown_field": "unknown_value", + "another_unknown": 12345 + }` + + // Unmarshal + var req2 EmbeddingRequest + err := json.Unmarshal([]byte(jsonWithUnknown), &req2) + assert.NoError(t, err) + + // Verify known fields + assert.Equal(t, "text-embedding-ada-002", req2.Model) + assert.Equal(t, 1, len(req2.Input.OfArrayOfStrings)) + + // Verify unknown fields are stored in RawJSON + assert.NotEmpty(t, req2.RawJSON) + + // Marshal back and verify unknown fields are preserved + data2, err := json.Marshal(req2) + assert.NoError(t, err) + + // Unmarshal into map to check all fields + var resultMap map[string]interface{} + err = json.Unmarshal(data2, &resultMap) + assert.NoError(t, err) + + // Check known fields + assert.Equal(t, "text-embedding-ada-002", resultMap["model"]) + inputArray, ok := resultMap["input"].([]interface{}) + assert.True(t, ok) + assert.Equal(t, "Hello, world!", inputArray[0]) + + // Check unknown fields + assert.Equal(t, "unknown_value", resultMap["unknown_field"]) + assert.Equal(t, 12345.0, resultMap["another_unknown"]) +}