Skip to content

Commit 06700b5

Browse files
committed
update
Signed-off-by: yxia216 <[email protected]>
1 parent 71b2589 commit 06700b5

File tree

8 files changed

+202
-31
lines changed

8 files changed

+202
-31
lines changed

internal/apischema/gcp/gcp.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,13 @@ type GenerateContentRequest struct {
3636
// https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L1057
3737
SafetySettings []*genai.SafetySetting `json:"safetySettings,omitempty"`
3838
}
39+
40+
type EmbedContentRequest struct {
41+
// Contains the multipart content of a message.
42+
//
43+
// https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L858
44+
Contents []genai.Content `json:"contents"`
45+
// Optional configuration for the embedding request.
46+
47+
Config *genai.EmbedContentConfig `json:"config,omitempty"`
48+
}

internal/apischema/openai/openai.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1520,6 +1520,11 @@ type EmbeddingCompletionRequest struct {
15201520
User *string `json:"user,omitempty"`
15211521
}
15221522

1523+
// GetModel implements ModelName interface
1524+
func (e *EmbeddingCompletionRequest) GetModel() string {
1525+
return e.Model
1526+
}
1527+
15231528
// EmbeddingChatRequest represents a request structure for the embeddings API. This is not part of the standard OpenAI API; it extends the request to carry messages in the same way chat completion requests do.
15241529
type EmbeddingChatRequest struct {
15251530
// Messages: A list of messages comprising the conversation so far.
@@ -1545,10 +1550,28 @@ type EmbeddingChatRequest struct {
15451550
User *string `json:"user,omitempty"`
15461551
}
15471552

1548-
type EmbedddingRequest interface {
1553+
// GetModel implements the ModelName interface.
1554+
func (e *EmbeddingChatRequest) GetModel() string {
1555+
return e.Model
1556+
}
1557+
1558+
type EmbeddingRequest interface {
15491559
EmbeddingCompletionRequest | EmbeddingChatRequest
15501560
}
15511561

1562+
// ModelName is implemented by types that can provide a model name.
1563+
type ModelName interface {
1564+
GetModel() string
1565+
}
1566+
1567+
// GetModelFromEmbeddingRequest extracts the model name from any EmbeddingRequest type
1568+
func GetModelFromEmbeddingRequest[T EmbeddingRequest](req *T) string {
1569+
if mp, ok := any(*req).(ModelName); ok {
1570+
return mp.GetModel()
1571+
}
1572+
return ""
1573+
}
1574+
15521575
// EmbeddingResponse represents a response from /v1/embeddings.
15531576
// https://platform.openai.com/docs/api-reference/embeddings/object
15541577
type EmbeddingResponse struct {

internal/extproc/embeddings_processor.go

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,14 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory {
3232
return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) {
3333
logger = logger.With("processor", "embeddings", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter))
3434
if !isUpstreamFilter {
35-
return &embeddingsProcessorRouterFilter{
35+
return &embeddingsProcessorRouterFilter[openai.EmbeddingCompletionRequest]{
3636
config: config,
3737
tracer: tracing.EmbeddingsTracer(),
3838
requestHeaders: requestHeaders,
3939
logger: logger,
4040
}, nil
4141
}
42-
return &embeddingsProcessorUpstreamFilter{
42+
return &embeddingsProcessorUpstreamFilter[openai.EmbeddingCompletionRequest]{
4343
config: config,
4444
requestHeaders: requestHeaders,
4545
logger: logger,
@@ -51,7 +51,7 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory {
5151
// embeddingsProcessorRouterFilter implements [Processor] for the `/v1/embeddings` endpoint.
5252
//
5353
// This is primarily used to select the route for the request based on the model name.
54-
type embeddingsProcessorRouterFilter struct {
54+
type embeddingsProcessorRouterFilter[T openai.EmbeddingRequest] struct {
5555
passThroughProcessor
5656
// upstreamFilter is the upstream filter that is used to process the request at the upstream filter.
5757
// This will be updated when the request is retried.
@@ -67,7 +67,7 @@ type embeddingsProcessorRouterFilter struct {
6767
// originalRequestBody is the original request body that is passed to the upstream filter.
6868
// This is used to perform the transformation of the request body on the original input
6969
// when the request is retried.
70-
originalRequestBody *openai.EmbeddingRequest
70+
originalRequestBody *T
7171
originalRequestBodyRaw []byte
7272
// tracer is the tracer used for requests.
7373
tracer tracing.EmbeddingsTracer
@@ -79,7 +79,7 @@ type embeddingsProcessorRouterFilter struct {
7979
}
8080

8181
// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
82-
func (e *embeddingsProcessorRouterFilter) ProcessResponseHeaders(ctx context.Context, headerMap *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
82+
func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseHeaders(ctx context.Context, headerMap *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
8383
// If the request failed to route and/or immediate response was returned before the upstream filter was set,
8484
// e.upstreamFilter can be nil.
8585
if e.upstreamFilter != nil { // See the comment on the "upstreamFilter" field.
@@ -89,7 +89,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessResponseHeaders(ctx context.Con
8989
}
9090

9191
// ProcessResponseBody implements [Processor.ProcessResponseBody].
92-
func (e *embeddingsProcessorRouterFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
92+
func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
9393
// If the request failed to route and/or immediate response was returned before the upstream filter was set,
9494
// e.upstreamFilter can be nil.
9595
if e.upstreamFilter != nil { // See the comment on the "upstreamFilter" field.
@@ -99,8 +99,8 @@ func (e *embeddingsProcessorRouterFilter) ProcessResponseBody(ctx context.Contex
9999
}
100100

101101
// ProcessRequestBody implements [Processor.ProcessRequestBody].
102-
func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
103-
originalModel, body, err := parseOpenAIEmbeddingBody(rawBody)
102+
func (e *embeddingsProcessorRouterFilter[T]) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
103+
originalModel, body, err := parseOpenAIEmbeddingBody[T](rawBody)
104104
if err != nil {
105105
return nil, fmt.Errorf("failed to parse request body: %w", err)
106106
}
@@ -125,7 +125,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context
125125
ctx,
126126
e.requestHeaders,
127127
&headerMutationCarrier{m: headerMutation},
128-
body,
128+
convertToEmbeddingCompletionRequest(body),
129129
rawBody.Body,
130130
)
131131

@@ -144,7 +144,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context
144144
// embeddingsProcessorUpstreamFilter implements [Processor] for the `/v1/embeddings` endpoint at the upstream filter.
145145
//
146146
// This is created per retry and handles the translation as well as the authentication of the request.
147-
type embeddingsProcessorUpstreamFilter struct {
147+
type embeddingsProcessorUpstreamFilter[T openai.EmbeddingRequest] struct {
148148
logger *slog.Logger
149149
config *filterapi.RuntimeConfig
150150
requestHeaders map[string]string
@@ -156,7 +156,7 @@ type embeddingsProcessorUpstreamFilter struct {
156156
headerMutator *headermutator.HeaderMutator
157157
bodyMutator *bodymutator.BodyMutator
158158
originalRequestBodyRaw []byte
159-
originalRequestBody *openai.EmbeddingRequest
159+
originalRequestBody *T
160160
translator translator.OpenAIEmbeddingTranslator
161161
// onRetry is true if this is a retry request at the upstream filter.
162162
onRetry bool
@@ -169,12 +169,14 @@ type embeddingsProcessorUpstreamFilter struct {
169169
}
170170

171171
// selectTranslator selects the translator based on the output schema.
172-
func (e *embeddingsProcessorUpstreamFilter) selectTranslator(out filterapi.VersionedAPISchema) error {
172+
func (e *embeddingsProcessorUpstreamFilter[T]) selectTranslator(out filterapi.VersionedAPISchema) error {
173173
switch out.Name {
174174
case filterapi.APISchemaOpenAI:
175175
e.translator = translator.NewEmbeddingOpenAIToOpenAITranslator(out.Version, e.modelNameOverride)
176176
case filterapi.APISchemaAzureOpenAI:
177177
e.translator = translator.NewEmbeddingOpenAIToAzureOpenAITranslator(out.Version, e.modelNameOverride)
178+
case filterapi.APISchemaGCPVertexAI:
179+
e.translator = translator.NewEmbeddingOpenAIToAzureOpenAITranslator(out.Version, e.modelNameOverride)
178180
default:
179181
return fmt.Errorf("unsupported API schema: backend=%s", out)
180182
}
@@ -187,7 +189,7 @@ func (e *embeddingsProcessorUpstreamFilter) selectTranslator(out filterapi.Versi
187189
// So, we simply do the translation and upstream auth at this stage, and send them back to Envoy
188190
// with the status CONTINUE_AND_REPLACE. This allows Envoy to avoid sending the request body again
189191
// to the extproc.
190-
func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
192+
func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
191193
defer func() {
192194
if err != nil {
193195
e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
@@ -197,12 +199,12 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Co
197199
// Start tracking metrics for this request.
198200
e.metrics.StartRequest(e.requestHeaders)
199201
// Set the original model from the request body before any overrides
200-
e.metrics.SetOriginalModel(e.originalRequestBody.Model)
202+
e.metrics.SetOriginalModel(openai.GetModelFromEmbeddingRequest(e.originalRequestBody))
201203
// Set the request model for metrics from the original model or override if applied.
202-
reqModel := cmp.Or(e.requestHeaders[internalapi.ModelNameHeaderKeyDefault], e.originalRequestBody.Model)
204+
reqModel := cmp.Or(e.requestHeaders[internalapi.ModelNameHeaderKeyDefault], openai.GetModelFromEmbeddingRequest(e.originalRequestBody))
203205
e.metrics.SetRequestModel(reqModel)
204206

205-
newHeaders, newBody, err := e.translator.RequestBody(e.originalRequestBodyRaw, e.originalRequestBody, e.onRetry)
207+
newHeaders, newBody, err := e.translator.RequestBody(e.originalRequestBodyRaw, convertToEmbeddingCompletionRequest(e.originalRequestBody), e.onRetry)
206208
if err != nil {
207209
return nil, fmt.Errorf("failed to transform request: %w", err)
208210
}
@@ -265,12 +267,12 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Co
265267
}
266268

267269
// ProcessRequestBody implements [Processor.ProcessRequestBody].
268-
func (e *embeddingsProcessorUpstreamFilter) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
270+
func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
269271
panic("BUG: ProcessRequestBody should not be called in the upstream filter")
270272
}
271273

272274
// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
273-
func (e *embeddingsProcessorUpstreamFilter) ProcessResponseHeaders(ctx context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
275+
func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseHeaders(ctx context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
274276
defer func() {
275277
if err != nil {
276278
e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
@@ -294,7 +296,7 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessResponseHeaders(ctx context.C
294296
}
295297

296298
// ProcessResponseBody implements [Processor.ProcessResponseBody].
297-
func (e *embeddingsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
299+
func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
298300
recordRequestCompletionErr := false
299301
defer func() {
300302
if err != nil || recordRequestCompletionErr {
@@ -383,13 +385,13 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Cont
383385
}
384386

385387
// SetBackend implements [Processor.SetBackend].
386-
func (e *embeddingsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) {
388+
func (e *embeddingsProcessorUpstreamFilter[T]) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) {
387389
defer func() {
388390
if err != nil {
389391
e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
390392
}
391393
}()
392-
rp, ok := routeProcessor.(*embeddingsProcessorRouterFilter)
394+
rp, ok := routeProcessor.(*embeddingsProcessorRouterFilter[T])
393395
if !ok {
394396
panic("BUG: expected routeProcessor to be of type *embeddingsProcessorRouterFilter")
395397
}
@@ -417,10 +419,31 @@ func (e *embeddingsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *f
417419
return
418420
}
419421

420-
func parseOpenAIEmbeddingBody(body *extprocv3.HttpBody) (modelName string, rb *openai.EmbeddingRequest, err error) {
421-
var openAIReq openai.EmbeddingRequest
422+
// convertToEmbeddingCompletionRequest converts any EmbeddingRequest to EmbeddingCompletionRequest for compatibility
423+
func convertToEmbeddingCompletionRequest[T openai.EmbeddingRequest](req *T) *openai.EmbeddingCompletionRequest {
424+
switch r := any(*req).(type) {
425+
case openai.EmbeddingCompletionRequest:
426+
return &r
427+
case openai.EmbeddingChatRequest:
428+
// Convert EmbeddingChatRequest to EmbeddingCompletionRequest by flattening messages to input
429+
// This is a simplified conversion - in practice you might need more sophisticated logic
430+
return &openai.EmbeddingCompletionRequest{
431+
Model: r.Model,
432+
Input: openai.EmbeddingRequestInput{Value: "converted_from_chat"}, // Simplified
433+
EncodingFormat: r.EncodingFormat,
434+
Dimensions: r.Dimensions,
435+
User: r.User,
436+
}
437+
default:
438+
return &openai.EmbeddingCompletionRequest{}
439+
}
440+
}
441+
442+
func parseOpenAIEmbeddingBody[T openai.EmbeddingRequest](body *extprocv3.HttpBody) (modelName string, rb *T, err error) {
443+
var openAIReq T
422444
if err := json.Unmarshal(body.Body, &openAIReq); err != nil {
423445
return "", nil, fmt.Errorf("failed to unmarshal body: %w", err)
424446
}
425-
return openAIReq.Model, &openAIReq, nil
447+
448+
return openai.GetModelFromEmbeddingRequest(&openAIReq), &openAIReq, nil
426449
}

internal/tracing/api/api.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ type (
5656
// CompletionTracer creates spans for OpenAI completion requests.
5757
CompletionTracer = RequestTracer[openai.CompletionRequest, CompletionSpan]
5858
// EmbeddingsTracer creates spans for OpenAI embeddings requests.
59-
EmbeddingsTracer = RequestTracer[openai.EmbeddingRequest, EmbeddingsSpan]
59+
EmbeddingsTracer = RequestTracer[openai.EmbeddingCompletionRequest, EmbeddingsSpan]
6060
// ImageGenerationTracer creates spans for OpenAI image generation requests.
6161
ImageGenerationTracer = RequestTracer[openaisdk.ImageGenerateParams, ImageGenerationSpan]
6262
// RerankTracer creates spans for rerank requests.
@@ -116,7 +116,7 @@ type (
116116
// ImageGenerationRecorder records attributes to a span according to a semantic convention.
117117
ImageGenerationRecorder = SpanRecorder[openaisdk.ImageGenerateParams, struct{}, openaisdk.ImagesResponse]
118118
// EmbeddingsRecorder records attributes to a span according to a semantic convention.
119-
EmbeddingsRecorder = SpanRecorder[openai.EmbeddingRequest, struct{}, openai.EmbeddingResponse]
119+
EmbeddingsRecorder = SpanRecorder[openai.EmbeddingCompletionRequest, struct{}, openai.EmbeddingResponse]
120120
// RerankRecorder records attributes to a span according to a semantic convention.
121121
RerankRecorder = SpanRecorder[cohere.RerankV2Request, struct{}, cohere.RerankV2Response]
122122
)
@@ -146,7 +146,7 @@ func (NoopTracing) CompletionTracer() CompletionTracer {
146146

147147
// EmbeddingsTracer implements Tracing.EmbeddingsTracer.
148148
func (NoopTracing) EmbeddingsTracer() EmbeddingsTracer {
149-
return NoopTracer[openai.EmbeddingRequest, EmbeddingsSpan]{}
149+
return NoopTracer[openai.EmbeddingCompletionRequest, EmbeddingsSpan]{}
150150
}
151151

152152
// ImageGenerationTracer implements Tracing.ImageGenerationTracer.

internal/translator/openai_azureopenai_embeddings.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type openAIToAzureOpenAITranslatorV1Embedding struct {
3333
}
3434

3535
// RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
36-
func (o *openAIToAzureOpenAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingRequest, onRetry bool) (
36+
func (o *openAIToAzureOpenAITranslatorV1Embedding) RequestBody(original []byte, req *openai.EmbeddingCompletionRequest, onRetry bool) (
3737
newHeaders []internalapi.Header, newBody []byte, err error,
3838
) {
3939
modelName := req.Model

internal/translator/openai_embeddings.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ type openAIToOpenAITranslatorV1Embedding struct {
3535
}
3636

3737
// RequestBody implements [OpenAIEmbeddingTranslator.RequestBody].
38-
func (o *openAIToOpenAITranslatorV1Embedding) RequestBody(original []byte, _ *openai.EmbeddingRequest, onRetry bool) (
38+
func (o *openAIToOpenAITranslatorV1Embedding) RequestBody(original []byte, _ *openai.EmbeddingCompletionRequest, onRetry bool) (
3939
newHeaders []internalapi.Header, newBody []byte, err error,
4040
) {
4141
if o.modelNameOverride != "" {

0 commit comments

Comments
 (0)