Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions internal/apischema/gcp/gcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,13 @@ type GenerateContentRequest struct {
// https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L1057
SafetySettings []*genai.SafetySetting `json:"safetySettings,omitempty"`
}

// EmbedContentRequest represents the request body for the GCP Gemini
// embedContent API.
type EmbedContentRequest struct {
	// Contains the multipart content of a message.
	//
	// https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L858
	Contents []genai.Content `json:"contents"`
	// Config holds optional settings for the embedding request; see
	// genai.EmbedContentConfig for the available fields.
	// (The previous comment here described tools — a copy-paste leftover.)
	Config *genai.EmbedContentConfig `json:"config,omitempty"`
}
56 changes: 54 additions & 2 deletions internal/apischema/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -1493,8 +1493,8 @@ type Model struct {
OwnedBy string `json:"owned_by"`
}

// EmbeddingRequest represents a request structure for embeddings API.
type EmbeddingRequest struct {
// EmbeddingCompletionRequest represents a request structure for embeddings API.
type EmbeddingCompletionRequest struct {
// Input: Input text to embed, encoded as a string or array of tokens.
// To embed multiple inputs in a single request, pass an array of strings or array of token arrays.
// The input must not exceed the max input tokens for the model (8192 tokens for text-embedding-ada-002),
Expand All @@ -1520,6 +1520,58 @@ type EmbeddingRequest struct {
User *string `json:"user,omitempty"`
}

// GetModel returns the model ID declared in the request.
// It implements the ModelName interface.
func (e *EmbeddingCompletionRequest) GetModel() string {
	return e.Model
}

// EmbeddingChatRequest represents a request structure for the embeddings API.
// This is not a standard OpenAI schema: it extends the embeddings request with
// a chat-completion-style "messages" field so that conversation content can be
// submitted for embedding.
type EmbeddingChatRequest struct {
	// Messages: A list of messages comprising the conversation so far.
	// Depending on the model you use, different message types (modalities) are supported,
	// like text, images, and audio.
	Messages []ChatCompletionMessageParamUnion `json:"messages"`

	// Model: ID of the model to use.
	// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-model
	Model string `json:"model"`

	// EncodingFormat: The format to return the embeddings in. Can be either float or base64.
	// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-encoding_format
	EncodingFormat *string `json:"encoding_format,omitempty"` //nolint:tagliatelle //follow openai api

	// Dimensions: The number of dimensions the resulting output embeddings should have.
	// Only supported in text-embedding-3 and later models.
	// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-dimensions
	Dimensions *int `json:"dimensions,omitempty"`

	// User: A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
	// Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-user
	User *string `json:"user,omitempty"`
}

// GetModel returns the model ID declared in the request.
// It implements the ModelName interface (not "ModelProvider" — no such
// interface exists in this package).
func (e *EmbeddingChatRequest) GetModel() string {
	return e.Model
}

// EmbeddingRequest is a generic type constraint covering the supported
// embeddings request shapes: the standard OpenAI form
// (EmbeddingCompletionRequest) and the chat-style extension
// (EmbeddingChatRequest). Being a union constraint, it can only be used as a
// type parameter bound, not as an ordinary interface value.
type EmbeddingRequest interface {
	EmbeddingCompletionRequest | EmbeddingChatRequest
}

// ModelName is implemented by request types that can report the ID of the
// model they target. Both embeddings request types implement it with pointer
// receivers.
type ModelName interface {
	GetModel() string
}

// GetModelFromEmbeddingRequest extracts the model name from any supported
// embeddings request type. It returns "" when req is nil.
//
// GetModel is declared with pointer receivers, so only *EmbeddingCompletionRequest
// and *EmbeddingChatRequest implement ModelName — the value types do not.
// The assertion therefore must be done on the pointer: the previous
// any(*req).(ModelName) boxed the dereferenced value, always failed, and made
// this function unconditionally return "".
func GetModelFromEmbeddingRequest[T EmbeddingRequest](req *T) string {
	if req == nil {
		return ""
	}
	if mp, ok := any(req).(ModelName); ok {
		return mp.GetModel()
	}
	return ""
}

// EmbeddingResponse represents a response from /v1/embeddings.
// https://platform.openai.com/docs/api-reference/embeddings/object
type EmbeddingResponse struct {
Expand Down
71 changes: 47 additions & 24 deletions internal/extproc/embeddings_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory {
return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) {
logger = logger.With("processor", "embeddings", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter))
if !isUpstreamFilter {
return &embeddingsProcessorRouterFilter{
return &embeddingsProcessorRouterFilter[openai.EmbeddingCompletionRequest]{
config: config,
tracer: tracing.EmbeddingsTracer(),
requestHeaders: requestHeaders,
logger: logger,
}, nil
}
return &embeddingsProcessorUpstreamFilter{
return &embeddingsProcessorUpstreamFilter[openai.EmbeddingCompletionRequest]{
config: config,
requestHeaders: requestHeaders,
logger: logger,
Expand All @@ -51,7 +51,7 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory {
// embeddingsProcessorRouterFilter implements [Processor] for the `/v1/embeddings` endpoint.
//
// This is primarily used to select the route for the request based on the model name.
type embeddingsProcessorRouterFilter struct {
type embeddingsProcessorRouterFilter[T openai.EmbeddingRequest] struct {
passThroughProcessor
// upstreamFilter is the upstream filter that is used to process the request at the upstream filter.
// This will be updated when the request is retried.
Expand All @@ -67,7 +67,7 @@ type embeddingsProcessorRouterFilter struct {
// originalRequestBody is the original request body that is passed to the upstream filter.
// This is used to perform the transformation of the request body on the original input
// when the request is retried.
originalRequestBody *openai.EmbeddingRequest
originalRequestBody *T
originalRequestBodyRaw []byte
// tracer is the tracer used for requests.
tracer tracing.EmbeddingsTracer
Expand All @@ -79,7 +79,7 @@ type embeddingsProcessorRouterFilter struct {
}

// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
func (e *embeddingsProcessorRouterFilter) ProcessResponseHeaders(ctx context.Context, headerMap *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseHeaders(ctx context.Context, headerMap *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
// If the request failed to route and/or immediate response was returned before the upstream filter was set,
// e.upstreamFilter can be nil.
if e.upstreamFilter != nil { // See the comment on the "upstreamFilter" field.
Expand All @@ -89,7 +89,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessResponseHeaders(ctx context.Con
}

// ProcessResponseBody implements [Processor.ProcessResponseBody].
func (e *embeddingsProcessorRouterFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
// If the request failed to route and/or immediate response was returned before the upstream filter was set,
// e.upstreamFilter can be nil.
if e.upstreamFilter != nil { // See the comment on the "upstreamFilter" field.
Expand All @@ -99,8 +99,8 @@ func (e *embeddingsProcessorRouterFilter) ProcessResponseBody(ctx context.Contex
}

// ProcessRequestBody implements [Processor.ProcessRequestBody].
func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
originalModel, body, err := parseOpenAIEmbeddingBody(rawBody)
func (e *embeddingsProcessorRouterFilter[T]) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
originalModel, body, err := parseOpenAIEmbeddingBody[T](rawBody)
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %w", err)
}
Expand All @@ -125,7 +125,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context
ctx,
e.requestHeaders,
&headerMutationCarrier{m: headerMutation},
body,
convertToEmbeddingCompletionRequest(body),
rawBody.Body,
)

Expand All @@ -144,7 +144,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context
// embeddingsProcessorUpstreamFilter implements [Processor] for the `/v1/embeddings` endpoint at the upstream filter.
//
// This is created per retry and handles the translation as well as the authentication of the request.
type embeddingsProcessorUpstreamFilter struct {
type embeddingsProcessorUpstreamFilter[T openai.EmbeddingRequest] struct {
logger *slog.Logger
config *filterapi.RuntimeConfig
requestHeaders map[string]string
Expand All @@ -156,7 +156,7 @@ type embeddingsProcessorUpstreamFilter struct {
headerMutator *headermutator.HeaderMutator
bodyMutator *bodymutator.BodyMutator
originalRequestBodyRaw []byte
originalRequestBody *openai.EmbeddingRequest
originalRequestBody *T
translator translator.OpenAIEmbeddingTranslator
// onRetry is true if this is a retry request at the upstream filter.
onRetry bool
Expand All @@ -169,12 +169,14 @@ type embeddingsProcessorUpstreamFilter struct {
}

// selectTranslator selects the translator based on the output schema.
func (e *embeddingsProcessorUpstreamFilter) selectTranslator(out filterapi.VersionedAPISchema) error {
func (e *embeddingsProcessorUpstreamFilter[T]) selectTranslator(out filterapi.VersionedAPISchema) error {
switch out.Name {
case filterapi.APISchemaOpenAI:
e.translator = translator.NewEmbeddingOpenAIToOpenAITranslator(out.Version, e.modelNameOverride)
case filterapi.APISchemaAzureOpenAI:
e.translator = translator.NewEmbeddingOpenAIToAzureOpenAITranslator(out.Version, e.modelNameOverride)
	case filterapi.APISchemaGCPVertexAI:
		// NOTE(review): this case constructs the *Azure* OpenAI translator —
		// almost certainly a copy-paste error from the case above. A GCP
		// Vertex AI embeddings translator should be used here; confirm one
		// exists before merging, otherwise GCP backends will receive
		// Azure-shaped requests.
		e.translator = translator.NewEmbeddingOpenAIToAzureOpenAITranslator(out.Version, e.modelNameOverride)
default:
return fmt.Errorf("unsupported API schema: backend=%s", out)
}
Expand All @@ -187,7 +189,7 @@ func (e *embeddingsProcessorUpstreamFilter) selectTranslator(out filterapi.Versi
// So, we simply do the translation and upstream auth at this stage, and send them back to Envoy
// with the status CONTINUE_AND_REPLACE. This will allows Envoy to not send the request body again
// to the extproc.
func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
defer func() {
if err != nil {
e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
Expand All @@ -197,12 +199,12 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Co
// Start tracking metrics for this request.
e.metrics.StartRequest(e.requestHeaders)
// Set the original model from the request body before any overrides
e.metrics.SetOriginalModel(e.originalRequestBody.Model)
e.metrics.SetOriginalModel(openai.GetModelFromEmbeddingRequest(e.originalRequestBody))
// Set the request model for metrics from the original model or override if applied.
reqModel := cmp.Or(e.requestHeaders[internalapi.ModelNameHeaderKeyDefault], e.originalRequestBody.Model)
reqModel := cmp.Or(e.requestHeaders[internalapi.ModelNameHeaderKeyDefault], openai.GetModelFromEmbeddingRequest(e.originalRequestBody))
e.metrics.SetRequestModel(reqModel)

newHeaders, newBody, err := e.translator.RequestBody(e.originalRequestBodyRaw, e.originalRequestBody, e.onRetry)
newHeaders, newBody, err := e.translator.RequestBody(e.originalRequestBodyRaw, convertToEmbeddingCompletionRequest(e.originalRequestBody), e.onRetry)
if err != nil {
return nil, fmt.Errorf("failed to transform request: %w", err)
}
Expand Down Expand Up @@ -265,12 +267,12 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Co
}

// ProcessRequestBody implements [Processor.ProcessRequestBody].
func (e *embeddingsProcessorUpstreamFilter) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
panic("BUG: ProcessRequestBody should not be called in the upstream filter")
}

// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
func (e *embeddingsProcessorUpstreamFilter) ProcessResponseHeaders(ctx context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseHeaders(ctx context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
defer func() {
if err != nil {
e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
Expand All @@ -294,7 +296,7 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessResponseHeaders(ctx context.C
}

// ProcessResponseBody implements [Processor.ProcessResponseBody].
func (e *embeddingsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
recordRequestCompletionErr := false
defer func() {
if err != nil || recordRequestCompletionErr {
Expand Down Expand Up @@ -383,13 +385,13 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Cont
}

// SetBackend implements [Processor.SetBackend].
func (e *embeddingsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) {
func (e *embeddingsProcessorUpstreamFilter[T]) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) {
defer func() {
if err != nil {
e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
}
}()
rp, ok := routeProcessor.(*embeddingsProcessorRouterFilter)
rp, ok := routeProcessor.(*embeddingsProcessorRouterFilter[T])
if !ok {
panic("BUG: expected routeProcessor to be of type *embeddingsProcessorRouterFilter")
}
Expand Down Expand Up @@ -417,10 +419,31 @@ func (e *embeddingsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *f
return
}

func parseOpenAIEmbeddingBody(body *extprocv3.HttpBody) (modelName string, rb *openai.EmbeddingRequest, err error) {
var openAIReq openai.EmbeddingRequest
// convertToEmbeddingCompletionRequest converts any supported embeddings
// request into the concrete EmbeddingCompletionRequest form that downstream
// code (translators, tracing) operates on.
func convertToEmbeddingCompletionRequest[T openai.EmbeddingRequest](req *T) *openai.EmbeddingCompletionRequest {
	switch r := any(*req).(type) {
	case openai.EmbeddingCompletionRequest:
		// Already the target type; return a pointer to the copy.
		return &r
	case openai.EmbeddingChatRequest:
		// NOTE(review): the chat messages are DISCARDED and replaced with the
		// placeholder string "converted_from_chat", so the upstream request
		// will not embed the user's actual conversation content. The messages
		// need to be flattened into Input before this can ship — confirm the
		// intended behavior.
		return &openai.EmbeddingCompletionRequest{
			Model:          r.Model,
			Input:          openai.EmbeddingRequestInput{Value: "converted_from_chat"}, // Simplified
			EncodingFormat: r.EncodingFormat,
			Dimensions:     r.Dimensions,
			User:           r.User,
		}
	default:
		// Unreachable while the EmbeddingRequest union has exactly the two
		// cases above; returns an empty request rather than panicking should
		// the union grow without this switch being updated.
		return &openai.EmbeddingCompletionRequest{}
	}
}

// parseOpenAIEmbeddingBody unmarshals the raw HTTP body into the embeddings
// request type T and returns the model name declared in the request along with
// the parsed body. It returns an error when the body is not valid JSON for T.
//
// (This also removes a leftover pre-refactor line — `return openAIReq.Model, …`
// — that was stranded in the middle of the function and broke compilation.)
func parseOpenAIEmbeddingBody[T openai.EmbeddingRequest](body *extprocv3.HttpBody) (modelName string, rb *T, err error) {
	var openAIReq T
	if err := json.Unmarshal(body.Body, &openAIReq); err != nil {
		return "", nil, fmt.Errorf("failed to unmarshal body: %w", err)
	}
	return openai.GetModelFromEmbeddingRequest(&openAIReq), &openAIReq, nil
}
6 changes: 3 additions & 3 deletions internal/tracing/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ type (
// CompletionTracer creates spans for OpenAI completion requests.
CompletionTracer = RequestTracer[openai.CompletionRequest, CompletionSpan]
// EmbeddingsTracer creates spans for OpenAI embeddings requests.
EmbeddingsTracer = RequestTracer[openai.EmbeddingRequest, EmbeddingsSpan]
EmbeddingsTracer = RequestTracer[openai.EmbeddingCompletionRequest, EmbeddingsSpan]
// ImageGenerationTracer creates spans for OpenAI image generation requests.
ImageGenerationTracer = RequestTracer[openaisdk.ImageGenerateParams, ImageGenerationSpan]
// RerankTracer creates spans for rerank requests.
Expand Down Expand Up @@ -116,7 +116,7 @@ type (
// ImageGenerationRecorder records attributes to a span according to a semantic convention.
ImageGenerationRecorder = SpanRecorder[openaisdk.ImageGenerateParams, struct{}, openaisdk.ImagesResponse]
// EmbeddingsRecorder records attributes to a span according to a semantic convention.
EmbeddingsRecorder = SpanRecorder[openai.EmbeddingRequest, struct{}, openai.EmbeddingResponse]
EmbeddingsRecorder = SpanRecorder[openai.EmbeddingCompletionRequest, struct{}, openai.EmbeddingResponse]
// RerankRecorder records attributes to a span according to a semantic convention.
RerankRecorder = SpanRecorder[cohere.RerankV2Request, struct{}, cohere.RerankV2Response]
)
Expand Down Expand Up @@ -146,7 +146,7 @@ func (NoopTracing) CompletionTracer() CompletionTracer {

// EmbeddingsTracer implements Tracing.EmbeddingsTracer.
// It returns a no-op tracer parameterized over the concrete
// EmbeddingCompletionRequest type, matching the EmbeddingsTracer alias.
func (NoopTracing) EmbeddingsTracer() EmbeddingsTracer {
	return NoopTracer[openai.EmbeddingCompletionRequest, EmbeddingsSpan]{}
}
}

// ImageGenerationTracer implements Tracing.ImageGenerationTracer.
Expand Down
2 changes: 1 addition & 1 deletion internal/translator/openai_azureopenai_embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type openAIToAzureOpenAITranslatorV1Embedding struct {
}

// RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
func (o *openAIToAzureOpenAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingRequest, onRetry bool) (
func (o *openAIToAzureOpenAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingCompletionRequest, onRetry bool) (
newHeaders []internalapi.Header, newBody []byte, err error,
) {
modelName := req.Model
Expand Down
2 changes: 1 addition & 1 deletion internal/translator/openai_embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type openAIToOpenAITranslatorV1Embedding struct {
}

// RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
func (o *openAIToOpenAITranslatorV1Embedding) RequestBody(original []byte, _ *openai.EmbeddingRequest, onRetry bool) (
func (o *openAIToOpenAITranslatorV1Embedding) RequestBody(original []byte, _ *openai.EmbeddingCompletionRequest, onRetry bool) (
newHeaders []internalapi.Header, newBody []byte, err error,
) {
if o.modelNameOverride != "" {
Expand Down
Loading