Skip to content

Commit 57c42ea

Browse files
committed
update
Signed-off-by: yxia216 <[email protected]>
1 parent 699aad6 commit 57c42ea

File tree

6 files changed

+152
-4
lines changed

6 files changed

+152
-4
lines changed

internal/apischema/gcp/gcp.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,13 @@ type GenerateContentRequest struct {
3636
// https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L1057
3737
SafetySettings []*genai.SafetySetting `json:"safetySettings,omitempty"`
3838
}
39+
40+
type EmbedContentRequest struct {
41+
// Contains the multipart content of a message.
42+
//
43+
// https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L858
44+
Contents []genai.Content `json:"contents"`
45+
// Config holds the optional embedding configuration for the request.
46+
47+
Config *genai.EmbedContentConfig `json:"config,omitempty"`
48+
}

internal/apischema/openai/openai.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1525,6 +1525,11 @@ type EmbeddingCompletionRequest struct {
15251525
User *string `json:"user,omitempty"`
15261526
}
15271527

1528+
// GetModel implements the ModelName interface.
1529+
func (e *EmbeddingCompletionRequest) GetModel() string {
1530+
return e.Model
1531+
}
1532+
15281533
// EmbeddingChatRequest represents a request structure for the embeddings API. It is not part of the standard OpenAI API; it extends the embeddings request to carry chat-style messages, like chat completion requests.
15291534
type EmbeddingChatRequest struct {
15301535
// Messages: A list of messages comprising the conversation so far.
@@ -1550,10 +1555,28 @@ type EmbeddingChatRequest struct {
15501555
User *string `json:"user,omitempty"`
15511556
}
15521557

1553-
type EmbedddingRequest interface {
1558+
// GetModel implements the ModelName interface.
1559+
func (e *EmbeddingChatRequest) GetModel() string {
1560+
return e.Model
1561+
}
1562+
1563+
type EmbeddingRequest interface {
15541564
EmbeddingCompletionRequest | EmbeddingChatRequest
15551565
}
15561566

1567+
// ModelName is implemented by request types that can report their model name.
1568+
type ModelName interface {
1569+
GetModel() string
1570+
}
1571+
1572+
// GetModelFromEmbeddingRequest extracts the model name from any EmbeddingRequest type
1573+
func GetModelFromEmbeddingRequest[T EmbeddingRequest](req *T) string {
1574+
if mp, ok := any(*req).(ModelName); ok {
1575+
return mp.GetModel()
1576+
}
1577+
return ""
1578+
}
1579+
15571580
// EmbeddingResponse represents a response from /v1/embeddings.
15581581
// https://platform.openai.com/docs/api-reference/embeddings/object
15591582
type EmbeddingResponse struct {

internal/translator/openai_azureopenai_embeddings.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type openAIToAzureOpenAITranslatorV1Embedding struct {
3333
}
3434

3535
// RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
36-
func (o *openAIToAzureOpenAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingRequest, onRetry bool) (
36+
func (o *openAIToAzureOpenAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingCompletionRequest, onRetry bool) (
3737
newHeaders []internalapi.Header, newBody []byte, err error,
3838
) {
3939
modelName := req.Model

internal/translator/openai_embeddings.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ type openAIToOpenAITranslatorV1Embedding struct {
3636
}
3737

3838
// RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
39-
func (o *openAIToOpenAITranslatorV1Embedding) RequestBody(original []byte, _ *openai.EmbeddingRequest, onRetry bool) (
39+
func (o *openAIToOpenAITranslatorV1Embedding) RequestBody(original []byte, _ *openai.EmbeddingCompletionRequest, onRetry bool) (
4040
newHeaders []internalapi.Header, newBody []byte, err error,
4141
) {
4242
if o.modelNameOverride != "" {
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// Copyright Envoy AI Gateway Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
// The full text of the Apache license is available in the LICENSE file at
4+
// the root of the repo.
5+
6+
package translator
7+
8+
import (
9+
"encoding/json"
10+
"fmt"
11+
"strconv"
12+
13+
"github.com/envoyproxy/ai-gateway/internal/apischema/gcp"
14+
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
15+
"github.com/envoyproxy/ai-gateway/internal/internalapi"
16+
)
17+
18+
// NewEmbeddingOpenAIToAzureOpenAITranslator implements [Factory] for OpenAI to Azure OpenAI translation
19+
// for embeddings.
20+
func NewEmbeddingOpenAIToGCPVertexAITranslator(requestModel internalapi.RequestModel, modelNameOverride internalapi.ModelNameOverride) OpenAIEmbeddingTranslator {
21+
return &openAIToGCPVertexAITranslatorV1Embedding{
22+
apiVersion: apiVersion,
23+
openAIToOpenAITranslatorV1Embedding: openAIToOpenAITranslatorV1Embedding{
24+
modelNameOverride: modelNameOverride,
25+
},
26+
}
27+
}
28+
29+
// openAIToGCPVertexAITranslatorV1Embedding implements [OpenAIEmbeddingTranslator] for /embeddings.
30+
type openAIToGCPVertexAITranslatorV1Embedding[T openai.EmbeddingRequest] struct {
31+
requestModel internalapi.RequestModel
32+
openAIToOpenAITranslatorV1Embedding
33+
}
34+
35+
36+
37+
// InputToGeminiConent switches on the dynamic type of input.Value and returns
// the value unchanged together with a string tag naming its kind ("string",
// "string_array", "token_array" or "token_array_batch"). Any other dynamic
// type yields a nil value, the tag "unknown" and an "unsupported input type"
// error.
//
// NOTE(review): the signature declares no result types although every branch
// returns three values, and the visible caller passes two arguments and
// expects two results — signature and call site need to be reconciled.
// NOTE(review): the name looks like a typo for InputToGeminiContent — confirm
// before renaming.
func InputToGeminiConent(input openai.EmbeddingRequestInput){
38+
switch v := input.Value.(type) {
39+
case string:
40+

41+
return v, "string", nil
42+
case []string:
43+
// Array of text inputs
44+
return v, "string_array", nil
45+
case []int64:
46+
// Array of token IDs
47+
return v, "token_array", nil
48+
case [][]int64:
49+
// Array of token ID arrays
50+
return v, "token_array_batch", nil
51+
default:
52+
return nil, "unknown", fmt.Errorf("unsupported input type: %T", v)
53+
}
54+
55+
56+
}
57+
58+
// openAIEmbeddingCompletionToGeminiMessage converts an OpenAI EmbeddingCompletionRequest to a GCP EmbedContentRequest.
//
// It converts the request input into Contents, builds the Config, and returns
// a pointer to the assembled gcp.EmbedContentRequest. An input conversion
// error is returned as-is; a config conversion error is wrapped with
// "error converting generation config".
59+
func openAIEmbeddingCompletionToGeminiMessage(openAIReq *openai.EmbeddingCompletionRequest, requestModel internalapi.RequestModel) (*gcp.EmbedContentRequest, error) {
60+
// Convert OpenAI EmbeddingRequest's input to Gemini Contents
// NOTE(review): InputToGeminiConent is declared with a single parameter and
// returns three values in its body; this call passes two arguments and
// expects two results — confirm the intended signature.
61+
contents, err := InputToGeminiConent(openAIReq.Input, requestModel)
62+
if err != nil {
63+
return nil, err
64+
}
65+
66+
// Convert generation config.
// NOTE(review): the double comma on the next line is a syntax error, and the
// helper's result must be assignable to the Config field
// (*genai.EmbedContentConfig) — verify which helper is intended here.
67+
embedConfig,, err := openAIReqToGeminiGenerationConfig(openAIReq, requestModel)
68+
if err != nil {
69+
return nil, fmt.Errorf("error converting generation config: %w", err)
70+
}
71+
72+
gcr := gcp.EmbedContentRequest{
73+
Contents: contents,
74+
Config: embedConfig,
75+
}
76+
77+
return &gcr, nil
78+
}
79+
80+
// RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
81+
func (o *openAIToGCPVertexAITranslatorV1Embedding[T]) RequestBody(original []byte, req *T, onRetry bool) (
82+
newHeaders []internalapi.Header, newBody []byte, err error,
83+
) {
84+
85+
o.requestModel = openai.GetModelFromEmbeddingRequest(req)
86+
if o.modelNameOverride != "" {
87+
// Use modelName override if set.
88+
o.requestModel = o.modelNameOverride
89+
}
90+
91+
// Choose the correct endpoint based on streaming.
92+
var path string
93+
94+
path = buildGCPModelPathSuffix(gcpModelPublisherGoogle, o.requestModel, gcpMethodGenerateContent)
95+
96+
switch any(*req).(type) {
97+
case openai.EmbeddingCompletionRequest:
98+
gcpReq, err := openAIEmbeddingCompletionToGeminiMessage(openAIReq, o.requestModel)
99+
case openai.EmbeddingChatRequest:
100+
gcpReq, err := openAIEmbeddingChatToGeminiMessage(openAIReq, o.requestModel)
101+
102+
default:
103+
return nil, nil, fmt.Errorf("request body is wrong: %w", err)
104+
}
105+
106+
newBody, err = json.Marshal(gcpReq)
107+
if err != nil {
108+
return nil, nil, fmt.Errorf("error marshaling Gemini request: %w", err)
109+
}
110+
newHeaders = []internalapi.Header{
111+
{pathHeaderName, path},
112+
{contentLengthHeaderName, strconv.Itoa(len(newBody))},
113+
}
114+
return
115+
}

internal/translator/translator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ type (
7878
// OpenAIChatCompletionTranslator translates the OpenAI's /chat/completions endpoint.
7979
OpenAIChatCompletionTranslator = Translator[openai.ChatCompletionRequest, tracing.ChatCompletionSpan]
8080
// OpenAIEmbeddingTranslator translates the OpenAI's /embeddings endpoint.
81-
OpenAIEmbeddingTranslator = Translator[openai.EmbeddingRequest, tracing.EmbeddingsSpan]
81+
OpenAIEmbeddingTranslator = Translator[openai.EmbeddingCompletionRequest, tracing.EmbeddingsSpan]
8282
// OpenAICompletionTranslator translates the OpenAI's /completions endpoint.
8383
OpenAICompletionTranslator = Translator[openai.CompletionRequest, tracing.CompletionSpan]
8484
// CohereRerankTranslator translates the Cohere's /v2/rerank endpoint.

0 commit comments

Comments
 (0)