
Commit e3251b7

update

Signed-off-by: yxia216 <[email protected]>
1 parent 06700b5 commit e3251b7

22 files changed: +1484, -195 lines

internal/apischema/gcp/gcp.go

Lines changed: 7 additions & 5 deletions
@@ -38,11 +38,13 @@ type GenerateContentRequest struct {
 }
 
 type EmbedContentRequest struct {
-    // Contains the multipart content of a message.
-    //
-    // https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L858
-    Contents []genai.Content `json:"contents"`
-    // Tool details of a tool that the model may use to generate a response.
+    // Content to be embedded. Only text content is supported for embeddings.
+    Content *genai.Content `json:"content"`
 
+    // Optional configuration for the embedding request.
+    // Uses the official genai library configuration structure.
     Config *genai.EmbedContentConfig `json:"config,omitempty"`
 }
+
+// Note: We now use genai.EmbedContentResponse directly instead of defining our own.
+// This provides better compatibility and includes metadata like token usage.
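For orientation, a minimal sketch (not part of this commit) of how the reshaped EmbedContentRequest could be populated with the google.golang.org/genai types referenced above. The struct is mirrored locally so the snippet is self-contained; the input text and the task type are illustrative assumptions.

package main

import (
    "encoding/json"
    "fmt"

    "google.golang.org/genai"
)

// Local mirror of the EmbedContentRequest shape introduced in the diff above.
type EmbedContentRequest struct {
    Content *genai.Content            `json:"content"`
    Config  *genai.EmbedContentConfig `json:"config,omitempty"`
}

func main() {
    req := EmbedContentRequest{
        // Only text content is supported for embeddings, per the field comment above.
        Content: &genai.Content{Parts: []*genai.Part{{Text: "hello world"}}},
        // Task type value is an assumption for illustration.
        Config: &genai.EmbedContentConfig{TaskType: "RETRIEVAL_QUERY"},
    }
    b, _ := json.Marshal(req)
    fmt.Println(string(b)) // roughly: {"content":{"parts":[{"text":"hello world"}]},"config":{"taskType":"RETRIEVAL_QUERY"}}
}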

internal/apischema/openai/openai.go

Lines changed: 59 additions & 8 deletions
@@ -1518,6 +1518,9 @@ type EmbeddingCompletionRequest struct {
     // User: A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
     // Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-user
     User *string `json:"user,omitempty"`
+
+    // GCPVertexAIEmbeddingVendorFields configures the GCP VertexAI specific fields during schema translation.
+    *GCPVertexAIEmbeddingVendorFields `json:",inline,omitempty"`
 }
 
 // GetModel implements ModelName interface
@@ -1548,26 +1551,67 @@ type EmbeddingChatRequest struct {
     // User: A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
     // Docs: https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-user
     User *string `json:"user,omitempty"`
+
+    // GCPVertexAIEmbeddingVendorFields configures the GCP VertexAI specific fields during schema translation.
+    *GCPVertexAIEmbeddingVendorFields `json:",inline,omitempty"`
 }
 
 // GetModel implements ModelProvider interface
 func (e *EmbeddingChatRequest) GetModel() string {
     return e.Model
 }
 
-type EmbeddingRequest interface {
-    EmbeddingCompletionRequest | EmbeddingChatRequest
+// EmbeddingRequest is a union type that can handle both EmbeddingCompletionRequest and EmbeddingChatRequest.
+type EmbeddingRequest struct {
+    OfCompletion *EmbeddingCompletionRequest `json:",omitzero,inline"`
+    OfChat       *EmbeddingChatRequest       `json:",omitzero,inline"`
+}
+
+// UnmarshalJSON implements json.Unmarshaler to handle both EmbeddingCompletionRequest and EmbeddingChatRequest.
+func (e *EmbeddingRequest) UnmarshalJSON(data []byte) error {
+    // Check for Messages field to distinguish EmbeddingChatRequest
+    messagesResult := gjson.GetBytes(data, "messages")
+    if messagesResult.Exists() {
+        var chatReq EmbeddingChatRequest
+        if err := json.Unmarshal(data, &chatReq); err != nil {
+            return err
+        }
+        e.OfChat = &chatReq
+        return nil
+    }
+
+    // Check for Input field to distinguish EmbeddingCompletionRequest
+    inputResult := gjson.GetBytes(data, "input")
+    if inputResult.Exists() {
+        var completionReq EmbeddingCompletionRequest
+        if err := json.Unmarshal(data, &completionReq); err != nil {
+            return err
+        }
+        e.OfCompletion = &completionReq
+        return nil
+    }
+
+    return errors.New("embedding request must have either 'input' field (EmbeddingCompletionRequest) or 'messages' field (EmbeddingChatRequest)")
 }
 
-// ModelName interface for types that can provide a model name
-type ModelName interface {
-    GetModel() string
+// MarshalJSON implements json.Marshaler.
+func (e EmbeddingRequest) MarshalJSON() ([]byte, error) {
+    if e.OfCompletion != nil {
+        return json.Marshal(e.OfCompletion)
+    }
+    if e.OfChat != nil {
+        return json.Marshal(e.OfChat)
+    }
+    return nil, errors.New("no embedding request to marshal")
 }
 
 // GetModelFromEmbeddingRequest extracts the model name from any EmbeddingRequest type
-func GetModelFromEmbeddingRequest[T EmbeddingRequest](req *T) string {
-    if mp, ok := any(*req).(ModelName); ok {
-        return mp.GetModel()
+func GetModelFromEmbeddingRequest(req *EmbeddingRequest) string {
+    if req.OfCompletion != nil {
+        return req.OfCompletion.GetModel()
+    }
+    if req.OfChat != nil {
+        return req.OfChat.GetModel()
     }
     return ""
 }
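The union type above replaces the earlier generic constraint, so callers decode once and branch on whichever variant was populated. A minimal usage sketch, mirroring what the updated tests do rather than code from the commit; it assumes the usual encoding/json, fmt, and log imports and that it runs inside the module, since the apischema/openai package is internal.

raw := []byte(`{"model":"text-embedding-ada-002","input":"test input"}`)

var req openai.EmbeddingRequest
if err := json.Unmarshal(raw, &req); err != nil {
    // A body with neither "input" nor "messages" ends up here, per UnmarshalJSON above.
    log.Fatal(err)
}
switch {
case req.OfCompletion != nil:
    fmt.Println("completion-style:", req.OfCompletion.Model) // text-embedding-ada-002
case req.OfChat != nil:
    fmt.Println("chat-style:", req.OfChat.Model)
}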
@@ -1645,6 +1689,13 @@ type EmbeddingUsage struct {
     TotalTokens int `json:"total_tokens"` //nolint:tagliatelle //follow openai api
 }
 
+// GCPVertexAIEmbeddingVendorFields contains GCP Vertex AI (Gemini) vendor-specific fields for embedding requests.
+type GCPVertexAIEmbeddingVendorFields struct {
+    // Type of task for which the embedding will be used.
+    // https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types#supported_task_types
+    TaskType string `json:"task_type,omitempty"`
+}
+
 // JSONUNIXTime is a helper type to marshal/unmarshal time.Time UNIX timestamps.
 type JSONUNIXTime time.Time
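Because the vendor struct is embedded inline in both request variants, a client can attach the Vertex-specific task type to an otherwise standard OpenAI embeddings body and it is picked up during decoding. A hedged sketch (field values are illustrative; same import assumptions as the previous snippet):

raw := []byte(`{"model":"text-embedding-ada-002","input":"hello","task_type":"RETRIEVAL_DOCUMENT"}`)

var req openai.EmbeddingCompletionRequest
if err := json.Unmarshal(raw, &req); err != nil {
    log.Fatal(err)
}
if req.GCPVertexAIEmbeddingVendorFields != nil {
    // task_type is promoted from the embedded vendor struct defined above.
    fmt.Println(req.TaskType) // RETRIEVAL_DOCUMENT
}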

internal/extproc/embeddings_processor.go

Lines changed: 31 additions & 32 deletions
@@ -32,14 +32,14 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory {
     return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) {
         logger = logger.With("processor", "embeddings", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter))
         if !isUpstreamFilter {
-            return &embeddingsProcessorRouterFilter[openai.EmbeddingCompletionRequest]{
+            return &embeddingsProcessorRouterFilter{
                 config: config,
                 tracer: tracing.EmbeddingsTracer(),
                 requestHeaders: requestHeaders,
                 logger: logger,
             }, nil
         }
-        return &embeddingsProcessorUpstreamFilter[openai.EmbeddingCompletionRequest]{
+        return &embeddingsProcessorUpstreamFilter{
             config: config,
             requestHeaders: requestHeaders,
             logger: logger,
@@ -51,7 +51,7 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory {
 // embeddingsProcessorRouterFilter implements [Processor] for the `/v1/embeddings` endpoint.
 //
 // This is primarily used to select the route for the request based on the model name.
-type embeddingsProcessorRouterFilter[T openai.EmbeddingRequest] struct {
+type embeddingsProcessorRouterFilter struct {
     passThroughProcessor
     // upstreamFilter is the upstream filter that is used to process the request at the upstream filter.
     // This will be updated when the request is retried.
@@ -67,7 +67,7 @@ type embeddingsProcessorRouterFilter[T openai.EmbeddingRequest] struct {
     // originalRequestBody is the original request body that is passed to the upstream filter.
     // This is used to perform the transformation of the request body on the original input
     // when the request is retried.
-    originalRequestBody *T
+    originalRequestBody *openai.EmbeddingRequest
     originalRequestBodyRaw []byte
     // tracer is the tracer used for requests.
     tracer tracing.EmbeddingsTracer
@@ -79,7 +79,7 @@ type embeddingsProcessorRouterFilter[T openai.EmbeddingRequest] struct {
 }
 
 // ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
-func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseHeaders(ctx context.Context, headerMap *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
+func (e *embeddingsProcessorRouterFilter) ProcessResponseHeaders(ctx context.Context, headerMap *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
     // If the request failed to route and/or immediate response was returned before the upstream filter was set,
     // e.upstreamFilter can be nil.
     if e.upstreamFilter != nil { // See the comment on the "upstreamFilter" field.
@@ -89,7 +89,7 @@ func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseHeaders(ctx context.
 }
 
 // ProcessResponseBody implements [Processor.ProcessResponseBody].
-func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
+func (e *embeddingsProcessorRouterFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
     // If the request failed to route and/or immediate response was returned before the upstream filter was set,
     // e.upstreamFilter can be nil.
     if e.upstreamFilter != nil { // See the comment on the "upstreamFilter" field.
@@ -99,8 +99,8 @@ func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseBody(ctx context.Con
 }
 
 // ProcessRequestBody implements [Processor.ProcessRequestBody].
-func (e *embeddingsProcessorRouterFilter[T]) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
-    originalModel, body, err := parseOpenAIEmbeddingBody[T](rawBody)
+func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
+    originalModel, body, err := parseOpenAIEmbeddingBody(rawBody)
     if err != nil {
         return nil, fmt.Errorf("failed to parse request body: %w", err)
     }
@@ -125,7 +125,7 @@ func (e *embeddingsProcessorRouterFilter[T]) ProcessRequestBody(ctx context.Cont
         ctx,
         e.requestHeaders,
         &headerMutationCarrier{m: headerMutation},
-        convertToEmbeddingCompletionRequest(body),
+        body,
         rawBody.Body,
     )
 
@@ -144,7 +144,7 @@ func (e *embeddingsProcessorRouterFilter[T]) ProcessRequestBody(ctx context.Cont
 // embeddingsProcessorUpstreamFilter implements [Processor] for the `/v1/embeddings` endpoint at the upstream filter.
 //
 // This is created per retry and handles the translation as well as the authentication of the request.
-type embeddingsProcessorUpstreamFilter[T openai.EmbeddingRequest] struct {
+type embeddingsProcessorUpstreamFilter struct {
     logger *slog.Logger
     config *filterapi.RuntimeConfig
     requestHeaders map[string]string
@@ -156,7 +156,7 @@ type embeddingsProcessorUpstreamFilter[T openai.EmbeddingRequest] struct {
     headerMutator *headermutator.HeaderMutator
     bodyMutator *bodymutator.BodyMutator
     originalRequestBodyRaw []byte
-    originalRequestBody *T
+    originalRequestBody *openai.EmbeddingRequest
     translator translator.OpenAIEmbeddingTranslator
     // onRetry is true if this is a retry request at the upstream filter.
     onRetry bool
@@ -169,14 +169,14 @@ type embeddingsProcessorUpstreamFilter[T openai.EmbeddingRequest] struct {
 }
 
 // selectTranslator selects the translator based on the output schema.
-func (e *embeddingsProcessorUpstreamFilter[T]) selectTranslator(out filterapi.VersionedAPISchema) error {
+func (e *embeddingsProcessorUpstreamFilter) selectTranslator(out filterapi.VersionedAPISchema) error {
     switch out.Name {
     case filterapi.APISchemaOpenAI:
         e.translator = translator.NewEmbeddingOpenAIToOpenAITranslator(out.Version, e.modelNameOverride)
     case filterapi.APISchemaAzureOpenAI:
         e.translator = translator.NewEmbeddingOpenAIToAzureOpenAITranslator(out.Version, e.modelNameOverride)
     case filterapi.APISchemaGCPVertexAI:
-        e.translator = translator.NewEmbeddingOpenAIToAzureOpenAITranslator(out.Version, e.modelNameOverride)
+        e.translator = translator.NewEmbeddingOpenAIToGCPVertexAITranslator("", e.modelNameOverride)
     default:
         return fmt.Errorf("unsupported API schema: backend=%s", out)
     }
@@ -189,7 +189,7 @@ func (e *embeddingsProcessorUpstreamFilter[T]) selectTranslator(out filterapi.Ve
 // So, we simply do the translation and upstream auth at this stage, and send them back to Envoy
 // with the status CONTINUE_AND_REPLACE. This will allows Envoy to not send the request body again
 // to the extproc.
-func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
+func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
     defer func() {
         if err != nil {
             e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
@@ -204,7 +204,7 @@ func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestHeaders(ctx context
     reqModel := cmp.Or(e.requestHeaders[internalapi.ModelNameHeaderKeyDefault], openai.GetModelFromEmbeddingRequest(e.originalRequestBody))
     e.metrics.SetRequestModel(reqModel)
 
-    newHeaders, newBody, err := e.translator.RequestBody(e.originalRequestBodyRaw, convertToEmbeddingCompletionRequest(e.originalRequestBody), e.onRetry)
+    newHeaders, newBody, err := e.translator.RequestBody(e.originalRequestBodyRaw, e.originalRequestBody, e.onRetry)
     if err != nil {
         return nil, fmt.Errorf("failed to transform request: %w", err)
     }
@@ -267,12 +267,12 @@ func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestHeaders(ctx context
 }
 
 // ProcessRequestBody implements [Processor.ProcessRequestBody].
-func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
+func (e *embeddingsProcessorUpstreamFilter) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
     panic("BUG: ProcessRequestBody should not be called in the upstream filter")
 }
 
 // ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
-func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseHeaders(ctx context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
+func (e *embeddingsProcessorUpstreamFilter) ProcessResponseHeaders(ctx context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
     defer func() {
         if err != nil {
             e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
@@ -296,7 +296,7 @@ func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseHeaders(ctx contex
 }
 
 // ProcessResponseBody implements [Processor.ProcessResponseBody].
-func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
+func (e *embeddingsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
     recordRequestCompletionErr := false
     defer func() {
         if err != nil || recordRequestCompletionErr {
@@ -385,13 +385,13 @@ func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseBody(ctx context.C
 }
 
 // SetBackend implements [Processor.SetBackend].
-func (e *embeddingsProcessorUpstreamFilter[T]) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) {
+func (e *embeddingsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) {
     defer func() {
         if err != nil {
             e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
         }
     }()
-    rp, ok := routeProcessor.(*embeddingsProcessorRouterFilter[T])
+    rp, ok := routeProcessor.(*embeddingsProcessorRouterFilter)
     if !ok {
         panic("BUG: expected routeProcessor to be of type *embeddingsProcessorRouterFilter")
     }
@@ -420,27 +420,26 @@ func (e *embeddingsProcessorUpstreamFilter[T]) SetBackend(ctx context.Context, b
 }
 
 // convertToEmbeddingCompletionRequest converts any EmbeddingRequest to EmbeddingCompletionRequest for compatibility
-func convertToEmbeddingCompletionRequest[T openai.EmbeddingRequest](req *T) *openai.EmbeddingCompletionRequest {
-    switch r := any(*req).(type) {
-    case openai.EmbeddingCompletionRequest:
-        return &r
-    case openai.EmbeddingChatRequest:
+func convertToEmbeddingCompletionRequest(req *openai.EmbeddingRequest) *openai.EmbeddingCompletionRequest {
+    if req.OfCompletion != nil {
+        return req.OfCompletion
+    } else if req.OfChat != nil {
         // Convert EmbeddingChatRequest to EmbeddingCompletionRequest by flattening messages to input
         // This is a simplified conversion - in practice you might need more sophisticated logic
         return &openai.EmbeddingCompletionRequest{
-            Model: r.Model,
+            Model: req.OfChat.Model,
             Input: openai.EmbeddingRequestInput{Value: "converted_from_chat"}, // Simplified
-            EncodingFormat: r.EncodingFormat,
-            Dimensions: r.Dimensions,
-            User: r.User,
+            EncodingFormat: req.OfChat.EncodingFormat,
+            Dimensions: req.OfChat.Dimensions,
+            User: req.OfChat.User,
         }
-    default:
+    } else {
         return &openai.EmbeddingCompletionRequest{}
     }
 }
 
-func parseOpenAIEmbeddingBody[T openai.EmbeddingRequest](body *extprocv3.HttpBody) (modelName string, rb *T, err error) {
-    var openAIReq T
+func parseOpenAIEmbeddingBody(body *extprocv3.HttpBody) (modelName string, rb *openai.EmbeddingRequest, err error) {
+    var openAIReq openai.EmbeddingRequest
     if err := json.Unmarshal(body.Body, &openAIReq); err != nil {
         return "", nil, fmt.Errorf("failed to unmarshal body: %w", err)
     }
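To make the simplified chat-to-completion fallback above concrete, a small sketch (not in the commit) of what convertToEmbeddingCompletionRequest returns for a chat-style request. It assumes it runs inside the extproc package, since the function is unexported, with the usual fmt import; the model name is borrowed from the tests for illustration.

chat := &openai.EmbeddingRequest{
    OfChat: &openai.EmbeddingChatRequest{Model: "text-embedding-ada-002"},
}
flat := convertToEmbeddingCompletionRequest(chat)
fmt.Println(flat.Model)       // text-embedding-ada-002
fmt.Println(flat.Input.Value) // "converted_from_chat", the placeholder noted in the code comment above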

internal/extproc/embeddings_processor_test.go

Lines changed: 18 additions & 7 deletions
@@ -371,7 +371,11 @@ func Test_embeddingsProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T)
 
 func TestEmbeddings_ProcessRequestHeaders_SetsRequestModel(t *testing.T) {
     headers := map[string]string{":path": "/v1/embeddings", internalapi.ModelNameHeaderKeyDefault: "header-model"}
-    body := openai.EmbeddingRequest{Model: "body-model"}
+    body := openai.EmbeddingRequest{
+        OfCompletion: &openai.EmbeddingCompletionRequest{
+            Model: "body-model",
+        },
+    }
     raw, _ := json.Marshal(body)
     mm := &mockMetrics{}
     p := &embeddingsProcessorUpstreamFilter{
@@ -395,8 +399,10 @@ func TestEmbeddings_ProcessResponseBody_OverridesHeaderModelWithResponseModel(t
     const modelKey = internalapi.ModelNameHeaderKeyDefault
     headers := map[string]string{":path": "/v1/embeddings", modelKey: "header-model"}
     body := openai.EmbeddingRequest{
-        Model: "body-model",
-        Input: openai.EmbeddingRequestInput{Value: "test"},
+        OfCompletion: &openai.EmbeddingCompletionRequest{
+            Model: "body-model",
+            Input: openai.EmbeddingRequestInput{Value: "test"},
+        },
     }
     raw, _ := json.Marshal(body)
     mm := &mockMetrics{}
@@ -454,8 +460,9 @@ func TestEmbeddings_ParseBody(t *testing.T) {
         require.NoError(t, err)
         require.Equal(t, "text-embedding-ada-002", modelName)
         require.NotNil(t, rb)
-        require.Equal(t, "text-embedding-ada-002", rb.Model)
-        require.Equal(t, "test input", rb.Input.Value)
+        require.NotNil(t, rb.OfCompletion, "should be a completion request")
+        require.Equal(t, "text-embedding-ada-002", rb.OfCompletion.Model)
+        require.Equal(t, "test input", rb.OfCompletion.Input.Value)
     })
     t.Run("error", func(t *testing.T) {
         modelName, rb, err := parseOpenAIEmbeddingBody(&extprocv3.HttpBody{})
@@ -680,7 +687,9 @@ func TestEmbeddingsProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMutatio
     }
 
     requestBody := &openai.EmbeddingRequest{
-        Model: "text-embedding-ada-002",
+        OfCompletion: &openai.EmbeddingCompletionRequest{
+            Model: "text-embedding-ada-002",
+        },
     }
     requestBodyRaw := []byte(`{"model": "text-embedding-ada-002", "input": "Hello world", "encoding_format": "float", "dimensions": 1536}`)
 
@@ -757,7 +766,9 @@ func TestEmbeddingsProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMutatio
 
     originalRequestBodyRaw := []byte(`{"model": "text-embedding-ada-002", "input": "Original input", "encoding_format": "float"}`)
     requestBody := &openai.EmbeddingRequest{
-        Model: "text-embedding-ada-002",
+        OfCompletion: &openai.EmbeddingCompletionRequest{
+            Model: "text-embedding-ada-002",
+        },
     }
 
     p := &embeddingsProcessorUpstreamFilter{
