From 71b2589dc741e7200d16ebc12d481df93400828b Mon Sep 17 00:00:00 2001
From: yxia216
Date: Mon, 24 Nov 2025 21:29:04 -0500
Subject: [PATCH 1/2] init

Signed-off-by: yxia216
---
 internal/apischema/openai/openai.go | 33 +++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/internal/apischema/openai/openai.go b/internal/apischema/openai/openai.go
index c2afcb7f0..3a2abac16 100644
--- a/internal/apischema/openai/openai.go
+++ b/internal/apischema/openai/openai.go
@@ -1493,8 +1493,8 @@ type Model struct {
 	OwnedBy string `json:"owned_by"`
 }
 
-// EmbeddingRequest represents a request structure for embeddings API.
-type EmbeddingRequest struct {
+// EmbeddingCompletionRequest represents a request structure for the embeddings API.
+type EmbeddingCompletionRequest struct {
 	// Input: Input text to embed, encoded as a string or array of tokens.
 	// To embed multiple inputs in a single request, pass an array of strings or array of token arrays.
 	// The input must not exceed the max input tokens for the model (8192 tokens for text-embedding-ada-002),
@@ -1520,6 +1520,35 @@ type EmbeddingRequest struct {
 	User *string `json:"user,omitempty"`
 }
 
+// EmbeddingChatRequest represents a request structure for the embeddings API. This is not a standard
+// OpenAI request; it extends the embeddings request with a chat-style messages field, mirroring chat
+// completion requests.
+type EmbeddingChatRequest struct {
+	// Messages: A list of messages comprising the conversation so far.
+	// Depending on the model you use, different message types (modalities) are supported,
+	// like text, images, and audio.
+	Messages []ChatCompletionMessageParamUnion `json:"messages"`
+
+	// Model: ID of the model to use.
+	// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-model
+	Model string `json:"model"`
+
+	// EncodingFormat: The format to return the embeddings in. Can be either float or base64.
+	// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-encoding_format
+	EncodingFormat *string `json:"encoding_format,omitempty"` //nolint:tagliatelle //follow openai api
+
+	// Dimensions: The number of dimensions the resulting output embeddings should have.
+	// Only supported in text-embedding-3 and later models.
+	// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-dimensions
+	Dimensions *int `json:"dimensions,omitempty"`
+
+	// User: A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
+	// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-user
+	User *string `json:"user,omitempty"`
+}
+
+// EmbeddingRequest constrains the generic embedding handling to the two supported request shapes.
+type EmbeddingRequest interface {
+	EmbeddingCompletionRequest | EmbeddingChatRequest
+}
+
 // EmbeddingResponse represents a response from /v1/embeddings.
 // https://platform.openai.com/docs/api-reference/embeddings/object
 type EmbeddingResponse struct {

From 06700b5788f4c639de7f9f3e93e6be1a8bf2cc41 Mon Sep 17 00:00:00 2001
From: yxia216
Date: Mon, 1 Dec 2025 23:56:13 -0500
Subject: [PATCH 2/2] update

Signed-off-by: yxia216
---
 internal/apischema/gcp/gcp.go            |  10 ++
 internal/apischema/openai/openai.go      |  25 +++-
 internal/extproc/embeddings_processor.go |  71 +++++++----
 internal/tracing/api/api.go              |   6 +-
 .../openai_azureopenai_embeddings.go     |   2 +-
 internal/translator/openai_embeddings.go |   2 +-
 .../openai_gcpvertexai_embeddings.go     | 115 ++++++++++++++++++
 internal/translator/translator.go        |   2 +-
 8 files changed, 202 insertions(+), 31 deletions(-)
 create mode 100644 internal/translator/openai_gcpvertexai_embeddings.go

diff --git a/internal/apischema/gcp/gcp.go b/internal/apischema/gcp/gcp.go
index b728d3a90..5ce0297e2 100644
--- a/internal/apischema/gcp/gcp.go
+++ b/internal/apischema/gcp/gcp.go
@@ -36,3 +36,13 @@ type GenerateContentRequest struct {
 	// https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L1057
 	SafetySettings []*genai.SafetySetting `json:"safetySettings,omitempty"`
 }
+
+// EmbedContentRequest represents a request to the Gemini embedContent API.
+type EmbedContentRequest struct {
+	// Contents is the multipart content to embed, one entry per input.
+	//
+	// https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L858
+	Contents []genai.Content `json:"contents"`
+	// Config holds optional embedding parameters such as the output dimensionality.
+	Config *genai.EmbedContentConfig `json:"config,omitempty"`
+}
diff --git a/internal/apischema/openai/openai.go b/internal/apischema/openai/openai.go
index 3a2abac16..d53955452 100644
--- a/internal/apischema/openai/openai.go
+++ b/internal/apischema/openai/openai.go
@@ -1520,6 +1520,11 @@ type EmbeddingCompletionRequest struct {
 	User *string `json:"user,omitempty"`
 }
 
+// GetModel implements the ModelName interface.
+func (e *EmbeddingCompletionRequest) GetModel() string {
+	return e.Model
+}
+
 // EmbeddingChatRequest represents a request structure for the embeddings API. This is not a standard
 // OpenAI request; it extends the embeddings request with a chat-style messages field, mirroring chat
 // completion requests.
 type EmbeddingChatRequest struct {
 	// Messages: A list of messages comprising the conversation so far.
@@ -1545,10 +1550,28 @@ type EmbeddingChatRequest struct {
 	User *string `json:"user,omitempty"`
 }
 
+// GetModel implements the ModelName interface.
+func (e *EmbeddingChatRequest) GetModel() string {
+	return e.Model
+}
+
 // EmbeddingRequest constrains the generic embedding handling to the two supported request shapes.
 type EmbeddingRequest interface {
 	EmbeddingCompletionRequest | EmbeddingChatRequest
 }
 
+// ModelName is implemented by request types that can report their model name.
+type ModelName interface {
+	GetModel() string
+}
+
+// GetModelFromEmbeddingRequest extracts the model name from any EmbeddingRequest type.
+func GetModelFromEmbeddingRequest[T EmbeddingRequest](req *T) string {
+	// GetModel is declared on pointer receivers, so assert on the pointer rather than the value.
+	if mp, ok := any(req).(ModelName); ok {
+		return mp.GetModel()
+	}
+	return ""
+}
+
 // EmbeddingResponse represents a response from /v1/embeddings.
// https://platform.openai.com/docs/api-reference/embeddings/object type EmbeddingResponse struct { diff --git a/internal/extproc/embeddings_processor.go b/internal/extproc/embeddings_processor.go index 2af6b63a1..81bfa3bf9 100644 --- a/internal/extproc/embeddings_processor.go +++ b/internal/extproc/embeddings_processor.go @@ -32,14 +32,14 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory { return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { logger = logger.With("processor", "embeddings", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter)) if !isUpstreamFilter { - return &embeddingsProcessorRouterFilter{ + return &embeddingsProcessorRouterFilter[openai.EmbeddingCompletionRequest]{ config: config, tracer: tracing.EmbeddingsTracer(), requestHeaders: requestHeaders, logger: logger, }, nil } - return &embeddingsProcessorUpstreamFilter{ + return &embeddingsProcessorUpstreamFilter[openai.EmbeddingCompletionRequest]{ config: config, requestHeaders: requestHeaders, logger: logger, @@ -51,7 +51,7 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory { // embeddingsProcessorRouterFilter implements [Processor] for the `/v1/embeddings` endpoint. // // This is primarily used to select the route for the request based on the model name. -type embeddingsProcessorRouterFilter struct { +type embeddingsProcessorRouterFilter[T openai.EmbeddingRequest] struct { passThroughProcessor // upstreamFilter is the upstream filter that is used to process the request at the upstream filter. // This will be updated when the request is retried. @@ -67,7 +67,7 @@ type embeddingsProcessorRouterFilter struct { // originalRequestBody is the original request body that is passed to the upstream filter. // This is used to perform the transformation of the request body on the original input // when the request is retried. - originalRequestBody *openai.EmbeddingRequest + originalRequestBody *T originalRequestBodyRaw []byte // tracer is the tracer used for requests. tracer tracing.EmbeddingsTracer @@ -79,7 +79,7 @@ type embeddingsProcessorRouterFilter struct { } // ProcessResponseHeaders implements [Processor.ProcessResponseHeaders]. -func (e *embeddingsProcessorRouterFilter) ProcessResponseHeaders(ctx context.Context, headerMap *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) { +func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseHeaders(ctx context.Context, headerMap *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) { // If the request failed to route and/or immediate response was returned before the upstream filter was set, // e.upstreamFilter can be nil. if e.upstreamFilter != nil { // See the comment on the "upstreamFilter" field. @@ -89,7 +89,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessResponseHeaders(ctx context.Con } // ProcessResponseBody implements [Processor.ProcessResponseBody]. -func (e *embeddingsProcessorRouterFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) { +func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) { // If the request failed to route and/or immediate response was returned before the upstream filter was set, // e.upstreamFilter can be nil. if e.upstreamFilter != nil { // See the comment on the "upstreamFilter" field. 
@@ -99,8 +99,8 @@ func (e *embeddingsProcessorRouterFilter) ProcessResponseBody(ctx context.Contex
 }
 
 // ProcessRequestBody implements [Processor.ProcessRequestBody].
-func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
-	originalModel, body, err := parseOpenAIEmbeddingBody(rawBody)
+func (e *embeddingsProcessorRouterFilter[T]) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
+	originalModel, body, err := parseOpenAIEmbeddingBody[T](rawBody)
 	if err != nil {
 		return nil, fmt.Errorf("failed to parse request body: %w", err)
 	}
@@ -125,7 +125,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context
 		ctx,
 		e.requestHeaders,
 		&headerMutationCarrier{m: headerMutation},
-		body,
+		convertToEmbeddingCompletionRequest(body),
 		rawBody.Body,
 	)
 
@@ -144,7 +144,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context
 // embeddingsProcessorUpstreamFilter implements [Processor] for the `/v1/embeddings` endpoint at the upstream filter.
 //
 // This is created per retry and handles the translation as well as the authentication of the request.
-type embeddingsProcessorUpstreamFilter struct {
+type embeddingsProcessorUpstreamFilter[T openai.EmbeddingRequest] struct {
 	logger         *slog.Logger
 	config         *filterapi.RuntimeConfig
 	requestHeaders map[string]string
@@ -156,7 +156,7 @@ type embeddingsProcessorUpstreamFilter struct {
 	headerMutator          *headermutator.HeaderMutator
 	bodyMutator            *bodymutator.BodyMutator
 	originalRequestBodyRaw []byte
-	originalRequestBody    *openai.EmbeddingRequest
+	originalRequestBody    *T
 	translator             translator.OpenAIEmbeddingTranslator
 	// onRetry is true if this is a retry request at the upstream filter.
 	onRetry bool
@@ -169,12 +169,14 @@ type embeddingsProcessorUpstreamFilter struct {
 }
 
 // selectTranslator selects the translator based on the output schema.
-func (e *embeddingsProcessorUpstreamFilter) selectTranslator(out filterapi.VersionedAPISchema) error {
+func (e *embeddingsProcessorUpstreamFilter[T]) selectTranslator(out filterapi.VersionedAPISchema) error {
 	switch out.Name {
 	case filterapi.APISchemaOpenAI:
 		e.translator = translator.NewEmbeddingOpenAIToOpenAITranslator(out.Version, e.modelNameOverride)
 	case filterapi.APISchemaAzureOpenAI:
 		e.translator = translator.NewEmbeddingOpenAIToAzureOpenAITranslator(out.Version, e.modelNameOverride)
+	case filterapi.APISchemaGCPVertexAI:
+		e.translator = translator.NewEmbeddingOpenAIToGCPVertexAITranslator(e.modelNameOverride)
 	default:
 		return fmt.Errorf("unsupported API schema: backend=%s", out)
 	}
@@ -187,7 +189,7 @@ func (e *embeddingsProcessorUpstreamFilter) selectTranslator(out filterapi.Versi
 // So, we simply do the translation and upstream auth at this stage, and send them back to Envoy
 // with the status CONTINUE_AND_REPLACE. This will allows Envoy to not send the request body again
 // to the extproc.
-func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { +func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { defer func() { if err != nil { e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders) @@ -197,12 +199,12 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Co // Start tracking metrics for this request. e.metrics.StartRequest(e.requestHeaders) // Set the original model from the request body before any overrides - e.metrics.SetOriginalModel(e.originalRequestBody.Model) + e.metrics.SetOriginalModel(openai.GetModelFromEmbeddingRequest(e.originalRequestBody)) // Set the request model for metrics from the original model or override if applied. - reqModel := cmp.Or(e.requestHeaders[internalapi.ModelNameHeaderKeyDefault], e.originalRequestBody.Model) + reqModel := cmp.Or(e.requestHeaders[internalapi.ModelNameHeaderKeyDefault], openai.GetModelFromEmbeddingRequest(e.originalRequestBody)) e.metrics.SetRequestModel(reqModel) - newHeaders, newBody, err := e.translator.RequestBody(e.originalRequestBodyRaw, e.originalRequestBody, e.onRetry) + newHeaders, newBody, err := e.translator.RequestBody(e.originalRequestBodyRaw, convertToEmbeddingCompletionRequest(e.originalRequestBody), e.onRetry) if err != nil { return nil, fmt.Errorf("failed to transform request: %w", err) } @@ -265,12 +267,12 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Co } // ProcessRequestBody implements [Processor.ProcessRequestBody]. -func (e *embeddingsProcessorUpstreamFilter) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { +func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { panic("BUG: ProcessRequestBody should not be called in the upstream filter") } // ProcessResponseHeaders implements [Processor.ProcessResponseHeaders]. -func (e *embeddingsProcessorUpstreamFilter) ProcessResponseHeaders(ctx context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { +func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseHeaders(ctx context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { defer func() { if err != nil { e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders) @@ -294,7 +296,7 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessResponseHeaders(ctx context.C } // ProcessResponseBody implements [Processor.ProcessResponseBody]. -func (e *embeddingsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { +func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { recordRequestCompletionErr := false defer func() { if err != nil || recordRequestCompletionErr { @@ -383,13 +385,13 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Cont } // SetBackend implements [Processor.SetBackend]. 
-func (e *embeddingsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) {
+func (e *embeddingsProcessorUpstreamFilter[T]) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) {
 	defer func() {
 		if err != nil {
 			e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
 		}
 	}()
-	rp, ok := routeProcessor.(*embeddingsProcessorRouterFilter)
+	rp, ok := routeProcessor.(*embeddingsProcessorRouterFilter[T])
 	if !ok {
 		panic("BUG: expected routeProcessor to be of type *embeddingsProcessorRouterFilter")
 	}
@@ -417,10 +419,31 @@ func (e *embeddingsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *f
 	return
 }
 
-func parseOpenAIEmbeddingBody(body *extprocv3.HttpBody) (modelName string, rb *openai.EmbeddingRequest, err error) {
-	var openAIReq openai.EmbeddingRequest
+// convertToEmbeddingCompletionRequest converts any EmbeddingRequest to an EmbeddingCompletionRequest
+// so that downstream components (tracing, metrics, translators) operate on a single request shape.
+func convertToEmbeddingCompletionRequest[T openai.EmbeddingRequest](req *T) *openai.EmbeddingCompletionRequest {
+	switch r := any(*req).(type) {
+	case openai.EmbeddingCompletionRequest:
+		return &r
+	case openai.EmbeddingChatRequest:
+		// Carry over the fields shared with the completion shape.
+		// TODO: flatten the chat messages into the embedding input; a placeholder value is used
+		// until proper message flattening is implemented.
+		return &openai.EmbeddingCompletionRequest{
+			Model:          r.Model,
+			Input:          openai.EmbeddingRequestInput{Value: "converted_from_chat"}, // Placeholder, see TODO above.
+			EncodingFormat: r.EncodingFormat,
+			Dimensions:     r.Dimensions,
+			User:           r.User,
+		}
+	default:
+		// Unreachable: the EmbeddingRequest constraint only admits the two cases above.
+		return &openai.EmbeddingCompletionRequest{}
+	}
+}
+
+// parseOpenAIEmbeddingBody unmarshals the request body into the concrete request shape T and
+// extracts the model name.
+func parseOpenAIEmbeddingBody[T openai.EmbeddingRequest](body *extprocv3.HttpBody) (modelName string, rb *T, err error) {
+	var openAIReq T
 	if err := json.Unmarshal(body.Body, &openAIReq); err != nil {
 		return "", nil, fmt.Errorf("failed to unmarshal body: %w", err)
 	}
-	return openAIReq.Model, &openAIReq, nil
+	return openai.GetModelFromEmbeddingRequest(&openAIReq), &openAIReq, nil
 }
diff --git a/internal/tracing/api/api.go b/internal/tracing/api/api.go
index ed182a398..36b9a8a18 100644
--- a/internal/tracing/api/api.go
+++ b/internal/tracing/api/api.go
@@ -56,7 +56,7 @@ type (
 	// CompletionTracer creates spans for OpenAI completion requests.
 	CompletionTracer = RequestTracer[openai.CompletionRequest, CompletionSpan]
 	// EmbeddingsTracer creates spans for OpenAI embeddings requests.
-	EmbeddingsTracer = RequestTracer[openai.EmbeddingRequest, EmbeddingsSpan]
+	EmbeddingsTracer = RequestTracer[openai.EmbeddingCompletionRequest, EmbeddingsSpan]
 	// ImageGenerationTracer creates spans for OpenAI image generation requests.
 	ImageGenerationTracer = RequestTracer[openaisdk.ImageGenerateParams, ImageGenerationSpan]
 	// RerankTracer creates spans for rerank requests.
@@ -116,7 +116,7 @@ type (
 	// ImageGenerationRecorder records attributes to a span according to a semantic convention.
 	ImageGenerationRecorder = SpanRecorder[openaisdk.ImageGenerateParams, struct{}, openaisdk.ImagesResponse]
 	// EmbeddingsRecorder records attributes to a span according to a semantic convention.
-	EmbeddingsRecorder = SpanRecorder[openai.EmbeddingRequest, struct{}, openai.EmbeddingResponse]
+	EmbeddingsRecorder = SpanRecorder[openai.EmbeddingCompletionRequest, struct{}, openai.EmbeddingResponse]
 	// RerankRecorder records attributes to a span according to a semantic convention.
 	RerankRecorder = SpanRecorder[cohere.RerankV2Request, struct{}, cohere.RerankV2Response]
 )
@@ -146,7 +146,7 @@ func (NoopTracing) CompletionTracer() CompletionTracer {
 
 // EmbeddingsTracer implements Tracing.EmbeddingsTracer.
 func (NoopTracing) EmbeddingsTracer() EmbeddingsTracer {
-	return NoopTracer[openai.EmbeddingRequest, EmbeddingsSpan]{}
+	return NoopTracer[openai.EmbeddingCompletionRequest, EmbeddingsSpan]{}
 }
 
 // ImageGenerationTracer implements Tracing.ImageGenerationTracer.
diff --git a/internal/translator/openai_azureopenai_embeddings.go b/internal/translator/openai_azureopenai_embeddings.go
index 377b51ae3..70b51695e 100644
--- a/internal/translator/openai_azureopenai_embeddings.go
+++ b/internal/translator/openai_azureopenai_embeddings.go
@@ -33,7 +33,7 @@ type openAIToAzureOpenAITranslatorV1Embedding struct {
 }
 
 // RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
-func (o *openAIToAzureOpenAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingRequest, onRetry bool) (
+func (o *openAIToAzureOpenAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingCompletionRequest, onRetry bool) (
 	newHeaders []internalapi.Header, newBody []byte, err error,
 ) {
 	modelName := req.Model
diff --git a/internal/translator/openai_embeddings.go b/internal/translator/openai_embeddings.go
index e24c15608..5e6682204 100644
--- a/internal/translator/openai_embeddings.go
+++ b/internal/translator/openai_embeddings.go
@@ -35,7 +35,7 @@ type openAIToOpenAITranslatorV1Embedding struct {
 }
 
 // RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
-func (o *openAIToOpenAITranslatorV1Embedding) RequestBody(original []byte, _ *openai.EmbeddingRequest, onRetry bool) (
+func (o *openAIToOpenAITranslatorV1Embedding) RequestBody(original []byte, _ *openai.EmbeddingCompletionRequest, onRetry bool) (
 	newHeaders []internalapi.Header, newBody []byte, err error,
 ) {
 	if o.modelNameOverride != "" {
diff --git a/internal/translator/openai_gcpvertexai_embeddings.go b/internal/translator/openai_gcpvertexai_embeddings.go
new file mode 100644
index 000000000..09523c74a
--- /dev/null
+++ b/internal/translator/openai_gcpvertexai_embeddings.go
@@ -0,0 +1,115 @@
+// Copyright Envoy AI Gateway Authors
+// SPDX-License-Identifier: Apache-2.0
+// The full text of the Apache license is available in the LICENSE file at
+// the root of the repo.
+
+package translator
+
+import (
+	"encoding/json"
+	"fmt"
+	"strconv"
+
+	"google.golang.org/genai"
+
+	"github.com/envoyproxy/ai-gateway/internal/apischema/gcp"
+	"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
+	"github.com/envoyproxy/ai-gateway/internal/internalapi"
+)
+
+// NewEmbeddingOpenAIToGCPVertexAITranslator implements [Factory] for OpenAI to GCP Vertex AI
+// translation for embeddings.
+func NewEmbeddingOpenAIToGCPVertexAITranslator(modelNameOverride internalapi.ModelNameOverride) OpenAIEmbeddingTranslator {
+	return &openAIToGCPVertexAITranslatorV1Embedding{
+		openAIToOpenAITranslatorV1Embedding: openAIToOpenAITranslatorV1Embedding{
+			modelNameOverride: modelNameOverride,
+		},
+	}
+}
+
+// openAIToGCPVertexAITranslatorV1Embedding implements [OpenAIEmbeddingTranslator] for /embeddings.
+type openAIToGCPVertexAITranslatorV1Embedding struct {
+	requestModel internalapi.RequestModel
+	openAIToOpenAITranslatorV1Embedding
+}
+
+// inputToGeminiContents converts the OpenAI embedding input union into Gemini contents.
+func inputToGeminiContents(input openai.EmbeddingRequestInput) ([]genai.Content, error) {
+	switch v := input.Value.(type) {
+	case string:
+		// A single text input becomes a single content entry.
+		return []genai.Content{{Role: "user", Parts: []*genai.Part{{Text: v}}}}, nil
+	case []string:
+		// An array of text inputs becomes one content entry per string.
+		contents := make([]genai.Content, 0, len(v))
+		for _, s := range v {
+			contents = append(contents, genai.Content{Role: "user", Parts: []*genai.Part{{Text: s}}})
+		}
+		return contents, nil
+	case []int64, [][]int64:
+		// Gemini embeds text, not token IDs, so token-array inputs are rejected.
+		return nil, fmt.Errorf("token array inputs are not supported for GCP Vertex AI embeddings")
+	default:
+		return nil, fmt.Errorf("unsupported input type: %T", v)
+	}
+}
+
+// embedConfigFromOpenAIRequest maps the optional OpenAI embedding parameters onto the Gemini
+// embedding config. Only the output dimensionality has a Gemini counterpart today.
+func embedConfigFromOpenAIRequest(openAIReq *openai.EmbeddingCompletionRequest) *genai.EmbedContentConfig {
+	if openAIReq.Dimensions == nil {
+		return nil
+	}
+	dim := int32(*openAIReq.Dimensions)
+	return &genai.EmbedContentConfig{OutputDimensionality: &dim}
+}
+
+// openAIEmbeddingCompletionToGeminiMessage converts an OpenAI embedding request to a GCP Gemini
+// EmbedContentRequest.
+func openAIEmbeddingCompletionToGeminiMessage(openAIReq *openai.EmbeddingCompletionRequest) (*gcp.EmbedContentRequest, error) {
+	// Convert the OpenAI request input to Gemini contents.
+	contents, err := inputToGeminiContents(openAIReq.Input)
+	if err != nil {
+		return nil, fmt.Errorf("error converting input: %w", err)
+	}
+
+	return &gcp.EmbedContentRequest{
+		Contents: contents,
+		Config:   embedConfigFromOpenAIRequest(openAIReq),
+	}, nil
+}
+
+// RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
+//
+// The upstream filter converts chat-shaped embedding requests to the completion shape before the
+// translator runs, so only a single request type is handled here.
+func (o *openAIToGCPVertexAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingCompletionRequest, onRetry bool) (
+	newHeaders []internalapi.Header, newBody []byte, err error,
+) {
+	o.requestModel = req.Model
+	if o.modelNameOverride != "" {
+		// Use the model name override if set.
+		o.requestModel = o.modelNameOverride
+	}
+
+	// Embeddings are never streamed, so the non-streaming embedContent method is always used.
+	// NOTE: this assumes the Gemini embedContent method; publisher embedding models that only
+	// expose :predict would need a different path.
+	path := buildGCPModelPathSuffix(gcpModelPublisherGoogle, o.requestModel, "embedContent")
+
+	gcpReq, err := openAIEmbeddingCompletionToGeminiMessage(req)
+	if err != nil {
+		return nil, nil, fmt.Errorf("error converting request: %w", err)
+	}
+
+	newBody, err = json.Marshal(gcpReq)
+	if err != nil {
+		return nil, nil, fmt.Errorf("error marshaling Gemini request: %w", err)
+	}
+	newHeaders = []internalapi.Header{
+		{pathHeaderName, path},
+		{contentLengthHeaderName, strconv.Itoa(len(newBody))},
+	}
+	return
+}
diff --git a/internal/translator/translator.go b/internal/translator/translator.go
index 48c60943a..9a7c41b78 100644
--- a/internal/translator/translator.go
+++ b/internal/translator/translator.go
@@ -79,7 +79,7 @@ type (
 	// OpenAIChatCompletionTranslator translates the OpenAI's /chat/completions endpoint.
 	OpenAIChatCompletionTranslator = Translator[openai.ChatCompletionRequest, tracing.ChatCompletionSpan]
 	// OpenAIEmbeddingTranslator translates the OpenAI's /embeddings endpoint.
-	OpenAIEmbeddingTranslator = Translator[openai.EmbeddingRequest, tracing.EmbeddingsSpan]
+	OpenAIEmbeddingTranslator = Translator[openai.EmbeddingCompletionRequest, tracing.EmbeddingsSpan]
 	// OpenAICompletionTranslator translates the OpenAI's /completions endpoint.
 	OpenAICompletionTranslator = Translator[openai.CompletionRequest, tracing.CompletionSpan]
 	// CohereRerankTranslator translates the Cohere's /v2/rerank endpoint.
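
The model-extraction helper in this series combines two Go mechanisms: a type-set union constraint (EmbeddingRequest) for compile-time dispatch, and an ordinary interface assertion (ModelName) to reach the shared accessor at runtime. A minimal standalone sketch of that pattern, using local stand-in types rather than the real internal/apischema/openai package:

	package main

	import "fmt"

	// Two request shapes, mirroring EmbeddingCompletionRequest and EmbeddingChatRequest.
	type completionReq struct{ Model string }
	type chatReq struct{ Model string }

	func (r *completionReq) GetModel() string { return r.Model }
	func (r *chatReq) GetModel() string       { return r.Model }

	// embeddingReq mirrors the EmbeddingRequest union constraint.
	type embeddingReq interface{ completionReq | chatReq }

	// modelName mirrors the ModelName interface.
	type modelName interface{ GetModel() string }

	// getModel mirrors GetModelFromEmbeddingRequest. The assertion is made on the
	// pointer (any(req)) because GetModel is declared on pointer receivers: the
	// dereferenced value's method set would never satisfy modelName.
	func getModel[T embeddingReq](req *T) string {
		if m, ok := any(req).(modelName); ok {
			return m.GetModel()
		}
		return ""
	}

	func main() {
		fmt.Println(getModel(&completionReq{Model: "text-embedding-3-small"}))
		fmt.Println(getModel(&chatReq{Model: "gemini-embedding-001"}))
	}

This is also why the patch's assertion must use any(req) rather than any(*req): with pointer receivers, asserting on the value silently falls through and returns the empty string.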
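
The chat-to-completion conversion leaves message flattening as a TODO and embeds a placeholder string. One plausible flattening policy, sketched with local stand-in types (not the actual ChatCompletionMessageParamUnion, whose content is a union of modalities), is to concatenate each message's text with a role prefix:

	package main

	import (
		"fmt"
		"strings"
	)

	// Local stand-ins for the request shapes; the real types live in internal/apischema/openai.
	type message struct {
		Role    string
		Content string
	}

	type chatEmbeddingReq struct {
		Model    string
		Messages []message
	}

	type completionEmbeddingReq struct {
		Model string
		Input string
	}

	// flatten concatenates message text into a single embedding input, one line per
	// message. Role prefixes keep speaker turns distinct in the embedded text.
	func flatten(req *chatEmbeddingReq) *completionEmbeddingReq {
		var b strings.Builder
		for _, m := range req.Messages {
			b.WriteString(m.Role)
			b.WriteString(": ")
			b.WriteString(m.Content)
			b.WriteString("\n")
		}
		return &completionEmbeddingReq{Model: req.Model, Input: b.String()}
	}

	func main() {
		out := flatten(&chatEmbeddingReq{
			Model: "text-embedding-3-small",
			Messages: []message{
				{Role: "user", Content: "What is Envoy AI Gateway?"},
				{Role: "assistant", Content: "A proxy for LLM traffic."},
			},
		})
		fmt.Printf("%s\n%s", out.Model, out.Input)
	}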
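
On the Vertex AI side, the body the new translator produces can be sanity-checked in isolation. The snippet below assumes the google.golang.org/genai field names referenced by internal/apischema/gcp (Content.Role, Part.Text, EmbedContentConfig.OutputDimensionality) and redeclares a local mirror of gcp.EmbedContentRequest rather than importing the internal package; it illustrates the intended wire shape, not the gateway's actual code path:

	package main

	import (
		"encoding/json"
		"fmt"

		"google.golang.org/genai"
	)

	// embedContentRequest mirrors gcp.EmbedContentRequest from the patch.
	type embedContentRequest struct {
		Contents []genai.Content           `json:"contents"`
		Config   *genai.EmbedContentConfig `json:"config,omitempty"`
	}

	func main() {
		// One content entry per input string, matching inputToGeminiContents.
		inputs := []string{"hello world", "embed me too"}
		contents := make([]genai.Content, 0, len(inputs))
		for _, s := range inputs {
			contents = append(contents, genai.Content{Role: "user", Parts: []*genai.Part{{Text: s}}})
		}

		// OpenAI "dimensions" maps onto Gemini's output dimensionality.
		dim := int32(256)
		body, err := json.Marshal(embedContentRequest{
			Contents: contents,
			Config:   &genai.EmbedContentConfig{OutputDimensionality: &dim},
		})
		if err != nil {
			panic(err)
		}
		fmt.Println(string(body))
	}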