@@ -32,14 +32,14 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory {
3232 return func (config * filterapi.RuntimeConfig , requestHeaders map [string ]string , logger * slog.Logger , tracing tracing.Tracing , isUpstreamFilter bool ) (Processor , error ) {
3333 logger = logger .With ("processor" , "embeddings" , "isUpstreamFilter" , fmt .Sprintf ("%v" , isUpstreamFilter ))
3434 if ! isUpstreamFilter {
35- return & embeddingsProcessorRouterFilter {
35+ return & embeddingsProcessorRouterFilter [openai. EmbeddingCompletionRequest ] {
3636 config : config ,
3737 tracer : tracing .EmbeddingsTracer (),
3838 requestHeaders : requestHeaders ,
3939 logger : logger ,
4040 }, nil
4141 }
42- return & embeddingsProcessorUpstreamFilter {
42+ return & embeddingsProcessorUpstreamFilter [openai. EmbeddingCompletionRequest ] {
4343 config : config ,
4444 requestHeaders : requestHeaders ,
4545 logger : logger ,
@@ -51,7 +51,7 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory {
5151// embeddingsProcessorRouterFilter implements [Processor] for the `/v1/embeddings` endpoint.
5252//
5353// This is primarily used to select the route for the request based on the model name.
54- type embeddingsProcessorRouterFilter struct {
54+ type embeddingsProcessorRouterFilter [ T openai. EmbeddingRequest ] struct {
5555 passThroughProcessor
5656 // upstreamFilter is the upstream filter that is used to process the request at the upstream filter.
5757 // This will be updated when the request is retried.
@@ -67,7 +67,7 @@ type embeddingsProcessorRouterFilter struct {
6767 // originalRequestBody is the original request body that is passed to the upstream filter.
6868 // This is used to perform the transformation of the request body on the original input
6969 // when the request is retried.
70- originalRequestBody * openai. EmbeddingRequest
70+ originalRequestBody * T
7171 originalRequestBodyRaw []byte
7272 // tracer is the tracer used for requests.
7373 tracer tracing.EmbeddingsTracer
@@ -79,7 +79,7 @@ type embeddingsProcessorRouterFilter struct {
7979}
8080
8181// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
82- func (e * embeddingsProcessorRouterFilter ) ProcessResponseHeaders (ctx context.Context , headerMap * corev3.HeaderMap ) (* extprocv3.ProcessingResponse , error ) {
82+ func (e * embeddingsProcessorRouterFilter [ T ] ) ProcessResponseHeaders (ctx context.Context , headerMap * corev3.HeaderMap ) (* extprocv3.ProcessingResponse , error ) {
8383 // If the request failed to route and/or immediate response was returned before the upstream filter was set,
8484 // e.upstreamFilter can be nil.
8585 if e .upstreamFilter != nil { // See the comment on the "upstreamFilter" field.
@@ -89,7 +89,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessResponseHeaders(ctx context.Con
8989}
9090
9191// ProcessResponseBody implements [Processor.ProcessResponseBody].
92- func (e * embeddingsProcessorRouterFilter ) ProcessResponseBody (ctx context.Context , body * extprocv3.HttpBody ) (* extprocv3.ProcessingResponse , error ) {
92+ func (e * embeddingsProcessorRouterFilter [ T ] ) ProcessResponseBody (ctx context.Context , body * extprocv3.HttpBody ) (* extprocv3.ProcessingResponse , error ) {
9393 // If the request failed to route and/or immediate response was returned before the upstream filter was set,
9494 // e.upstreamFilter can be nil.
9595 if e .upstreamFilter != nil { // See the comment on the "upstreamFilter" field.
@@ -99,8 +99,8 @@ func (e *embeddingsProcessorRouterFilter) ProcessResponseBody(ctx context.Contex
9999}
100100
101101// ProcessRequestBody implements [Processor.ProcessRequestBody].
102- func (e * embeddingsProcessorRouterFilter ) ProcessRequestBody (ctx context.Context , rawBody * extprocv3.HttpBody ) (* extprocv3.ProcessingResponse , error ) {
103- originalModel , body , err := parseOpenAIEmbeddingBody (rawBody )
102+ func (e * embeddingsProcessorRouterFilter [ T ] ) ProcessRequestBody (ctx context.Context , rawBody * extprocv3.HttpBody ) (* extprocv3.ProcessingResponse , error ) {
103+ originalModel , body , err := parseOpenAIEmbeddingBody [ T ] (rawBody )
104104 if err != nil {
105105 return nil , fmt .Errorf ("failed to parse request body: %w" , err )
106106 }
@@ -125,7 +125,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context
125125 ctx ,
126126 e .requestHeaders ,
127127 & headerMutationCarrier {m : headerMutation },
128- body ,
128+ convertToEmbeddingCompletionRequest ( body ) ,
129129 rawBody .Body ,
130130 )
131131
@@ -144,7 +144,7 @@ func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context
144144// embeddingsProcessorUpstreamFilter implements [Processor] for the `/v1/embeddings` endpoint at the upstream filter.
145145//
146146// This is created per retry and handles the translation as well as the authentication of the request.
147- type embeddingsProcessorUpstreamFilter struct {
147+ type embeddingsProcessorUpstreamFilter [ T openai. EmbeddingRequest ] struct {
148148 logger * slog.Logger
149149 config * filterapi.RuntimeConfig
150150 requestHeaders map [string ]string
@@ -156,7 +156,7 @@ type embeddingsProcessorUpstreamFilter struct {
156156 headerMutator * headermutator.HeaderMutator
157157 bodyMutator * bodymutator.BodyMutator
158158 originalRequestBodyRaw []byte
159- originalRequestBody * openai. EmbeddingRequest
159+ originalRequestBody * T
160160 translator translator.OpenAIEmbeddingTranslator
161161 // onRetry is true if this is a retry request at the upstream filter.
162162 onRetry bool
@@ -169,12 +169,14 @@ type embeddingsProcessorUpstreamFilter struct {
169169}
170170
171171// selectTranslator selects the translator based on the output schema.
172- func (e * embeddingsProcessorUpstreamFilter ) selectTranslator (out filterapi.VersionedAPISchema ) error {
172+ func (e * embeddingsProcessorUpstreamFilter [ T ] ) selectTranslator (out filterapi.VersionedAPISchema ) error {
173173 switch out .Name {
174174 case filterapi .APISchemaOpenAI :
175175 e .translator = translator .NewEmbeddingOpenAIToOpenAITranslator (out .Version , e .modelNameOverride )
176176 case filterapi .APISchemaAzureOpenAI :
177177 e .translator = translator .NewEmbeddingOpenAIToAzureOpenAITranslator (out .Version , e .modelNameOverride )
178+ case filterapi .APISchemaGCPVertexAI :
179+ e .translator = translator .NewEmbeddingOpenAIToAzureOpenAITranslator (out .Version , e .modelNameOverride )
178180 default :
179181 return fmt .Errorf ("unsupported API schema: backend=%s" , out )
180182 }
@@ -187,7 +189,7 @@ func (e *embeddingsProcessorUpstreamFilter) selectTranslator(out filterapi.Versi
187189// So, we simply do the translation and upstream auth at this stage, and send them back to Envoy
188190// with the status CONTINUE_AND_REPLACE. This will allows Envoy to not send the request body again
189191// to the extproc.
190- func (e * embeddingsProcessorUpstreamFilter ) ProcessRequestHeaders (ctx context.Context , _ * corev3.HeaderMap ) (res * extprocv3.ProcessingResponse , err error ) {
192+ func (e * embeddingsProcessorUpstreamFilter [ T ] ) ProcessRequestHeaders (ctx context.Context , _ * corev3.HeaderMap ) (res * extprocv3.ProcessingResponse , err error ) {
191193 defer func () {
192194 if err != nil {
193195 e .metrics .RecordRequestCompletion (ctx , false , e .requestHeaders )
@@ -197,12 +199,12 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Co
197199 // Start tracking metrics for this request.
198200 e .metrics .StartRequest (e .requestHeaders )
199201 // Set the original model from the request body before any overrides
200- e .metrics .SetOriginalModel (e .originalRequestBody . Model )
202+ e .metrics .SetOriginalModel (openai . GetModelFromEmbeddingRequest ( e .originalRequestBody ) )
201203 // Set the request model for metrics from the original model or override if applied.
202- reqModel := cmp .Or (e .requestHeaders [internalapi .ModelNameHeaderKeyDefault ], e .originalRequestBody . Model )
204+ reqModel := cmp .Or (e .requestHeaders [internalapi .ModelNameHeaderKeyDefault ], openai . GetModelFromEmbeddingRequest ( e .originalRequestBody ) )
203205 e .metrics .SetRequestModel (reqModel )
204206
205- newHeaders , newBody , err := e .translator .RequestBody (e .originalRequestBodyRaw , e .originalRequestBody , e .onRetry )
207+ newHeaders , newBody , err := e .translator .RequestBody (e .originalRequestBodyRaw , convertToEmbeddingCompletionRequest ( e .originalRequestBody ) , e .onRetry )
206208 if err != nil {
207209 return nil , fmt .Errorf ("failed to transform request: %w" , err )
208210 }
@@ -265,12 +267,12 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Co
265267}
266268
267269// ProcessRequestBody implements [Processor.ProcessRequestBody].
268- func (e * embeddingsProcessorUpstreamFilter ) ProcessRequestBody (context.Context , * extprocv3.HttpBody ) (res * extprocv3.ProcessingResponse , err error ) {
270+ func (e * embeddingsProcessorUpstreamFilter [ T ] ) ProcessRequestBody (context.Context , * extprocv3.HttpBody ) (res * extprocv3.ProcessingResponse , err error ) {
269271 panic ("BUG: ProcessRequestBody should not be called in the upstream filter" )
270272}
271273
272274// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
273- func (e * embeddingsProcessorUpstreamFilter ) ProcessResponseHeaders (ctx context.Context , headers * corev3.HeaderMap ) (res * extprocv3.ProcessingResponse , err error ) {
275+ func (e * embeddingsProcessorUpstreamFilter [ T ] ) ProcessResponseHeaders (ctx context.Context , headers * corev3.HeaderMap ) (res * extprocv3.ProcessingResponse , err error ) {
274276 defer func () {
275277 if err != nil {
276278 e .metrics .RecordRequestCompletion (ctx , false , e .requestHeaders )
@@ -294,7 +296,7 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessResponseHeaders(ctx context.C
294296}
295297
296298// ProcessResponseBody implements [Processor.ProcessResponseBody].
297- func (e * embeddingsProcessorUpstreamFilter ) ProcessResponseBody (ctx context.Context , body * extprocv3.HttpBody ) (res * extprocv3.ProcessingResponse , err error ) {
299+ func (e * embeddingsProcessorUpstreamFilter [ T ] ) ProcessResponseBody (ctx context.Context , body * extprocv3.HttpBody ) (res * extprocv3.ProcessingResponse , err error ) {
298300 recordRequestCompletionErr := false
299301 defer func () {
300302 if err != nil || recordRequestCompletionErr {
@@ -383,13 +385,13 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Cont
383385}
384386
385387// SetBackend implements [Processor.SetBackend].
386- func (e * embeddingsProcessorUpstreamFilter ) SetBackend (ctx context.Context , b * filterapi.Backend , backendHandler filterapi.BackendAuthHandler , routeProcessor Processor ) (err error ) {
388+ func (e * embeddingsProcessorUpstreamFilter [ T ] ) SetBackend (ctx context.Context , b * filterapi.Backend , backendHandler filterapi.BackendAuthHandler , routeProcessor Processor ) (err error ) {
387389 defer func () {
388390 if err != nil {
389391 e .metrics .RecordRequestCompletion (ctx , false , e .requestHeaders )
390392 }
391393 }()
392- rp , ok := routeProcessor .(* embeddingsProcessorRouterFilter )
394+ rp , ok := routeProcessor .(* embeddingsProcessorRouterFilter [ T ] )
393395 if ! ok {
394396 panic ("BUG: expected routeProcessor to be of type *embeddingsProcessorRouterFilter" )
395397 }
@@ -417,10 +419,31 @@ func (e *embeddingsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *f
417419 return
418420}
419421
420- func parseOpenAIEmbeddingBody (body * extprocv3.HttpBody ) (modelName string , rb * openai.EmbeddingRequest , err error ) {
421- var openAIReq openai.EmbeddingRequest
422+ // convertToEmbeddingCompletionRequest converts any EmbeddingRequest to EmbeddingCompletionRequest for compatibility
423+ func convertToEmbeddingCompletionRequest [T openai.EmbeddingRequest ](req * T ) * openai.EmbeddingCompletionRequest {
424+ switch r := any (* req ).(type ) {
425+ case openai.EmbeddingCompletionRequest :
426+ return & r
427+ case openai.EmbeddingChatRequest :
428+ // Convert EmbeddingChatRequest to EmbeddingCompletionRequest by flattening messages to input
429+ // This is a simplified conversion - in practice you might need more sophisticated logic
430+ return & openai.EmbeddingCompletionRequest {
431+ Model : r .Model ,
432+ Input : openai.EmbeddingRequestInput {Value : "converted_from_chat" }, // Simplified
433+ EncodingFormat : r .EncodingFormat ,
434+ Dimensions : r .Dimensions ,
435+ User : r .User ,
436+ }
437+ default :
438+ return & openai.EmbeddingCompletionRequest {}
439+ }
440+ }
441+
442+ func parseOpenAIEmbeddingBody [T openai.EmbeddingRequest ](body * extprocv3.HttpBody ) (modelName string , rb * T , err error ) {
443+ var openAIReq T
422444 if err := json .Unmarshal (body .Body , & openAIReq ); err != nil {
423445 return "" , nil , fmt .Errorf ("failed to unmarshal body: %w" , err )
424446 }
425- return openAIReq .Model , & openAIReq , nil
447+
448+ return openai .GetModelFromEmbeddingRequest (& openAIReq ), & openAIReq , nil
426449}
0 commit comments