Skip to content

Commit 71b2e68

Browse files
committed
update
Signed-off-by: yxia216 <[email protected]>
1 parent c7ca20a commit 71b2e68

22 files changed

+840
-965
lines changed

internal/apischema/gcp/gcp.go

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
package gcp
77

8-
import "google.golang.org/genai"
8+
import (
9+
"google.golang.org/genai"
10+
11+
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
12+
)
913

1014
type GenerateContentRequest struct {
1115
// Contains the multipart content of a message.
@@ -37,14 +41,47 @@ type GenerateContentRequest struct {
3741
SafetySettings []*genai.SafetySetting `json:"safetySettings,omitempty"`
3842
}
3943

40-
type EmbedContentRequest struct {
41-
// Content to be embedded. Only text content is supported for embeddings.
42-
Content *genai.Content `json:"content"`
44+
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#syntax
45+
type Instance struct {
46+
// The text that you want to generate embeddings for.
47+
Content string `json:"content"`
48+
49+
// Used to convey intended downstream application to help the model produce better embeddings. If left blank, the default used is RETRIEVAL_QUERY.
50+
// For more information about task types, see https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types
51+
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#task_type
52+
TaskType openai.EmbeddingTaskType `json:"task_type,omitempty"`
53+
54+
// Used to help the model produce better embeddings. Only valid with task_type=RETRIEVAL_DOCUMENT.
55+
Title string `json:"title,omitempty"`
56+
}
57+
58+
// Parameters carries the request-level options of a Vertex AI text-embeddings predict request.
// The JSON keys use the camelCase names listed in the API reference:
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#parameter-list
type Parameters struct {
	// AutoTruncate: when set to true, input text will be truncated. When set to false, an error is
	// returned if the input text is longer than the maximum length supported by the model.
	// The service defaults to true.
	//
	// NOTE(review): because of omitempty a false value is never serialized, so the service-side
	// default applies; sending an explicit false would require a *bool field.
	AutoTruncate bool `json:"autoTruncate,omitempty"`

	// OutputDimensionality specifies the output embedding size. If set, output embeddings are
	// truncated to the size specified. Zero is omitted from the request.
	OutputDimensionality int `json:"outputDimensionality,omitempty"`
}
66+
67+
// https://github.com/googleapis/python-aiplatform/blob/30e41d01f3fd0ef08da6ad6eb7f83df34476105e/google/cloud/aiplatform_v1/types/prediction_service.py#L63
68+
type PredictRequest struct {
69+
// A list of instances
70+
//
71+
Instances []*Instance `json:"instances"`
4372

4473
// Optional configuration for the embedding request.
4574
// Carries the request-level parameters defined by the Parameters type.
46-
Config *genai.EmbedContentConfig `json:"config,omitempty"`
75+
Parameters Parameters `json:"parameters,omitempty"`
4776
}
4877

49-
// Note: We now use genai.EmbedContentResponse directly instead of defining our own.
50-
// This provides better compatibility and includes metadata like token usage.
78+
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#response_body
79+
type Prediction struct {
80+
// The result generated from input text.
81+
Embeddings genai.ContentEmbedding `json:"embeddings"`
82+
}
83+
84+
// https://github.com/googleapis/python-aiplatform/blob/30e41d01f3fd0ef08da6ad6eb7f83df34476105e/google/cloud/aiplatform_v1/types/prediction_service.py#L117
85+
type PredictResponse struct {
86+
Predictions []*Prediction `json:"predictions"`
87+
}

internal/apischema/openai/openai.go

Lines changed: 33 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,13 @@ func (c ContentUnion) MarshalJSON() ([]byte, error) {
310310
return json.Marshal(c.Value)
311311
}
312312

313+
// EmbeddingInputItem represents a single embedding input with optional metadata
314+
type EmbeddingInputItem struct {
315+
Content string `json:"content"` // The actual text content
316+
TaskType EmbeddingTaskType `json:"task_type,omitempty"` // Optional task type
317+
Title string `json:"title,omitempty"` // Optional title
318+
}
319+
313320
// EmbeddingRequestInput is the EmbeddingRequest.Input type.
314321
type EmbeddingRequestInput struct {
315322
Value any
@@ -1498,8 +1505,8 @@ type Model struct {
14981505
OwnedBy string `json:"owned_by"`
14991506
}
15001507

1501-
// EmbeddingCompletionRequest represents a request structure for embeddings API.
1502-
type EmbeddingCompletionRequest struct {
1508+
// EmbeddingRequest represents a request structure for embeddings API.
1509+
type EmbeddingRequest struct {
15031510
// Input: Input text to embed, encoded as a string or array of tokens.
15041511
// To embed multiple inputs in a single request, pass an array of strings or array of token arrays.
15051512
// The input must not exceed the max input tokens for the model (8192 tokens for text-embedding-ada-002),
@@ -1524,101 +1531,33 @@ type EmbeddingCompletionRequest struct {
15241531
// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-user
15251532
User *string `json:"user,omitempty"`
15261533

1527-
// GCPVertexAIEmbeddingVendorFields configures the GCP VertexAI specific fields during schema translation.
1534+
// GCPVertexAIEmbeddingVendorFields configures the GCP VertexAI specific fields for embedding during schema translation.
15281535
*GCPVertexAIEmbeddingVendorFields `json:",inline,omitempty"`
15291536
}
15301537

1531-
// GetModel implements ModelName interface
1532-
func (e *EmbeddingCompletionRequest) GetModel() string {
1533-
return e.Model
1534-
}
1535-
1536-
// EmbeddingChatRequest represents a request structure for embeddings API. This is not a standard openai, but just extend the request to have messages/chat like completion requests
1537-
type EmbeddingChatRequest struct {
1538-
// Messages: A list of messages comprising the conversation so far.
1539-
// Depending on the model you use, different message types (modalities) are supported,
1540-
// like text, images, and audio.
1541-
Messages []ChatCompletionMessageParamUnion `json:"messages"`
1542-
1543-
// Model: ID of the model to use.
1544-
// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-model
1545-
Model string `json:"model"`
1546-
1547-
// EncodingFormat: The format to return the embeddings in. Can be either float or base64.
1548-
// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-encoding_format
1549-
EncodingFormat *string `json:"encoding_format,omitempty"` //nolint:tagliatelle //follow openai api
1550-
1551-
// Dimensions: The number of dimensions the resulting output embeddings should have.
1552-
// Only supported in text-embedding-3 and later models.
1553-
// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-dimensions
1554-
Dimensions *int `json:"dimensions,omitempty"`
1555-
1556-
// User: A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
1557-
// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-user
1558-
User *string `json:"user,omitempty"`
1559-
1560-
// GCPVertexAIEmbeddingVendorFields configures the GCP VertexAI specific fields during schema translation.
1561-
*GCPVertexAIEmbeddingVendorFields `json:",inline,omitempty"`
1562-
}
1538+
type EmbeddingTaskType string
15631539

1564-
// GetModel implements ModelProvider interface
1565-
func (e *EmbeddingChatRequest) GetModel() string {
1566-
return e.Model
1567-
}
1568-
1569-
// EmbeddingRequest is a union type that can handle both EmbeddingCompletionRequest and EmbeddingChatRequest.
1570-
type EmbeddingRequest struct {
1571-
OfCompletion *EmbeddingCompletionRequest `json:",omitzero,inline"`
1572-
OfChat *EmbeddingChatRequest `json:",omitzero,inline"`
1573-
}
1574-
1575-
// UnmarshalJSON implements json.Unmarshaler to handle both EmbeddingCompletionRequest and EmbeddingChatRequest.
1576-
func (e *EmbeddingRequest) UnmarshalJSON(data []byte) error {
1577-
// Check for Messages field to distinguish EmbeddingChatRequest
1578-
messagesResult := gjson.GetBytes(data, "messages")
1579-
if messagesResult.Exists() {
1580-
var chatReq EmbeddingChatRequest
1581-
if err := json.Unmarshal(data, &chatReq); err != nil {
1582-
return err
1583-
}
1584-
e.OfChat = &chatReq
1585-
return nil
1586-
}
1587-
1588-
// Check for Input field to distinguish EmbeddingCompletionRequest
1589-
inputResult := gjson.GetBytes(data, "input")
1590-
if inputResult.Exists() {
1591-
var completionReq EmbeddingCompletionRequest
1592-
if err := json.Unmarshal(data, &completionReq); err != nil {
1593-
return err
1594-
}
1595-
e.OfCompletion = &completionReq
1596-
return nil
1597-
}
1540+
const (
1541+
EmbeddingTaskTypeRetrievalQuery EmbeddingTaskType = "RETRIEVAL_QUERY"
1542+
EmbeddingTaskTypeRetrievalDocument EmbeddingTaskType = "RETRIEVAL_DOCUMENT"
1543+
EmbeddingTaskTypeSemanticSimilarity EmbeddingTaskType = "SEMANTIC_SIMILARITY"
1544+
EmbeddingTaskTypeClassification EmbeddingTaskType = "CLASSIFICATION"
1545+
EmbeddingTaskTypeClustering EmbeddingTaskType = "CLUSTERING"
1546+
EmbeddingTaskTypeQuestionAnswering EmbeddingTaskType = "QUESTION_ANSWERING"
1547+
EmbeddingTaskTypeFactVerification EmbeddingTaskType = "FACT_VERIFICATION"
1548+
EmbeddingTaskTypeCodeRetrievalQuery EmbeddingTaskType = "CODE_RETRIEVAL_QUERY"
1549+
)
15981550

1599-
return errors.New("embedding request must have either 'input' field (EmbeddingCompletionRequest) or 'messages' field (EmbeddingChatRequest)")
1600-
}
1551+
// GCPVertexAIEmbeddingVendorFields contains GCP Vertex AI (Gemini) vendor-specific fields for embeddings.
1552+
type GCPVertexAIEmbeddingVendorFields struct {
1553+
// When set to true, input text will be truncated. When set to false, an error is returned if the input text is longer than the maximum length supported by the model. Defaults to true.
1554+
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#parameter-list
16011555

1602-
// MarshalJSON implements json.Marshaler.
1603-
func (e EmbeddingRequest) MarshalJSON() ([]byte, error) {
1604-
if e.OfCompletion != nil {
1605-
return json.Marshal(e.OfCompletion)
1606-
}
1607-
if e.OfChat != nil {
1608-
return json.Marshal(e.OfChat)
1609-
}
1610-
return nil, errors.New("no embedding request to marshal")
1611-
}
1556+
AutoTruncate bool `json:"auto_truncate,omitempty"`
16121557

1613-
// GetModelFromEmbeddingRequest extracts the model name from any EmbeddingRequest type
1614-
func GetModelFromEmbeddingRequest(req *EmbeddingRequest) string {
1615-
if req.OfCompletion != nil {
1616-
return req.OfCompletion.GetModel()
1617-
}
1618-
if req.OfChat != nil {
1619-
return req.OfChat.GetModel()
1620-
}
1621-
return ""
1558+
// Global task_type for the request, provided as a convenience for users. If left blank, the default used is RETRIEVAL_QUERY.
1559+
// For more information about task types, see https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types
1560+
TaskType EmbeddingTaskType `json:"task_type,omitempty"`
16221561
}
16231562

16241563
// EmbeddingResponse represents a response from /v1/embeddings.
@@ -1653,6 +1592,10 @@ type Embedding struct {
16531592

16541593
// Index: The index of the embedding in the list of embeddings.
16551594
Index int `json:"index"`
1595+
1596+
// Truncated reports whether the input text was truncated because it was longer than the allowed maximum input.
1597+
// https://github.com/googleapis/go-genai/blob/cb486e101dc66794d52125dd22ff43ff4c0e76a6/types.go#L2807
1598+
Truncated bool `json:"truncated,omitempty"`
16561599
}
16571600

16581601
// EmbeddingUnion is a union type that can handle both []float64 and string formats.
@@ -1694,13 +1637,6 @@ type EmbeddingUsage struct {
16941637
TotalTokens int `json:"total_tokens"` //nolint:tagliatelle //follow openai api
16951638
}
16961639

1697-
// GCPVertexAIEmbeddingVendorFields contains GCP Vertex AI (Gemini) vendor-specific fields for embedding requests.
1698-
type GCPVertexAIEmbeddingVendorFields struct {
1699-
// Type of task for which the embedding will be used.
1700-
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types#supported_task_types
1701-
TaskType string `json:"task_type,omitempty"`
1702-
}
1703-
17041640
// JSONUNIXTime is a helper type to marshal/unmarshal time.Time UNIX timestamps.
17051641
type JSONUNIXTime time.Time
17061642

internal/apischema/openai/union.go

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ func unmarshalJSONNestedUnion(typ string, data []byte) (interface{}, error) {
2323
case '"':
2424
return unquoteOrUnmarshalJSONString(typ, data)
2525

26+
case '{':
27+
// Single object with content/task_type/title
28+
var item EmbeddingInputItem
29+
err = json.Unmarshal(data, &item)
30+
if err != nil {
31+
return nil, fmt.Errorf("cannot unmarshal %s as EmbeddingInputItem: %w", typ, err)
32+
}
33+
// Validate that the content field is not empty
34+
if item.Content == "" {
35+
return nil, fmt.Errorf("invalid %s type (must be string, object, or array)", typ)
36+
}
37+
return item, nil
38+
2639
case '[':
2740
// Array: skip to first element
2841
idx++
@@ -38,13 +51,31 @@ func unmarshalJSONNestedUnion(typ string, data []byte) (interface{}, error) {
3851
// Determine element type
3952
switch data[idx] {
4053
case '"':
54+
// Check if this is a mixed array (strings and objects)
55+
if isMixedArray(data) {
56+
return unmarshalMixedArray(typ, data)
57+
}
4158
// []string
4259
var strs []string
4360
if err := json.Unmarshal(data, &strs); err != nil {
4461
return nil, fmt.Errorf("cannot unmarshal %s as []string: %w", typ, err)
4562
}
4663
return strs, nil
4764

65+
case '{':
66+
// []EmbeddingInputItem
67+
var items []EmbeddingInputItem
68+
if err := json.Unmarshal(data, &items); err != nil {
69+
return nil, fmt.Errorf("cannot unmarshal %s as []EmbeddingInputItem: %w", typ, err)
70+
}
71+
// Validate that all items have non-empty content
72+
for _, item := range items {
73+
if item.Content == "" {
74+
return nil, fmt.Errorf("invalid %s array element", typ)
75+
}
76+
}
77+
return items, nil
78+
4879
case '[':
4980
// [][]int64
5081
var intArrays [][]int64
@@ -60,7 +91,7 @@ func unmarshalJSONNestedUnion(typ string, data []byte) (interface{}, error) {
6091
}
6192

6293
default:
63-
return nil, fmt.Errorf("invalid %s type (must be string or array)", typ)
94+
return nil, fmt.Errorf("invalid %s type (must be string, object, or array)", typ)
6495
}
6596
}
6697

@@ -101,3 +132,86 @@ func unquoteOrUnmarshalJSONString(typ string, data []byte) (string, error) {
101132
}
102133
return str, nil
103134
}
135+
136+
// isMixedArray reports whether data is a JSON array that holds at least one
// string element and at least one object element. Input that cannot be decoded
// as a JSON array yields false.
func isMixedArray(data []byte) bool {
	var elems []json.RawMessage
	if json.Unmarshal(data, &elems) != nil {
		return false
	}

	var sawString, sawObject bool
	for _, raw := range elems {
		// The first byte that is not JSON whitespace identifies the element kind;
		// an element with no such byte leaves both flags untouched.
		for _, c := range raw {
			if c == ' ' || c == '\t' || c == '\n' || c == '\r' {
				continue
			}
			sawString = sawString || c == '"'
			sawObject = sawObject || c == '{'
			break
		}

		// Stop scanning as soon as both kinds have been observed.
		if sawString && sawObject {
			return true
		}
	}

	return false
}
172+
173+
// unmarshalMixedArray handles arrays with both strings and EmbeddingInputItem objects
174+
func unmarshalMixedArray(typ string, data []byte) (interface{}, error) {
175+
var arr []json.RawMessage
176+
if err := json.Unmarshal(data, &arr); err != nil {
177+
return nil, fmt.Errorf("cannot unmarshal %s as mixed array: %w", typ, err)
178+
}
179+
180+
result := make([]interface{}, len(arr))
181+
182+
for i, item := range arr {
183+
// Skip leading whitespace
184+
idx := 0
185+
for idx < len(item) && (item[idx] == ' ' || item[idx] == '\t' || item[idx] == '\n' || item[idx] == '\r') {
186+
idx++
187+
}
188+
if idx >= len(item) {
189+
return nil, fmt.Errorf("empty element in mixed %s array", typ)
190+
}
191+
192+
switch item[idx] {
193+
case '"':
194+
// String element
195+
var str string
196+
if err := json.Unmarshal(item, &str); err != nil {
197+
return nil, fmt.Errorf("cannot unmarshal string element in mixed %s array: %w", typ, err)
198+
}
199+
result[i] = str
200+
case '{':
201+
// Object element
202+
var embeddingItem EmbeddingInputItem
203+
if err := json.Unmarshal(item, &embeddingItem); err != nil {
204+
return nil, fmt.Errorf("cannot unmarshal object element in mixed %s array: %w", typ, err)
205+
}
206+
// Validate that the content field is not empty
207+
if embeddingItem.Content == "" {
208+
return nil, fmt.Errorf("invalid element type in mixed %s array", typ)
209+
}
210+
result[i] = embeddingItem
211+
default:
212+
return nil, fmt.Errorf("invalid element type in mixed %s array", typ)
213+
}
214+
}
215+
216+
return result, nil
217+
}

0 commit comments

Comments
 (0)