@@ -32,14 +32,14 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory {
 	return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) {
 		logger = logger.With("processor", "embeddings", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter))
 		if !isUpstreamFilter {
-			return &embeddingsProcessorRouterFilter[openai.EmbeddingCompletionRequest]{
+			return &embeddingsProcessorRouterFilter{
 				config:         config,
 				tracer:         tracing.EmbeddingsTracer(),
 				requestHeaders: requestHeaders,
 				logger:         logger,
 			}, nil
 		}
-		return &embeddingsProcessorUpstreamFilter[openai.EmbeddingCompletionRequest]{
+		return &embeddingsProcessorUpstreamFilter{
 			config:         config,
 			requestHeaders: requestHeaders,
 			logger:         logger,
@@ -51,7 +51,7 @@ func EmbeddingsProcessorFactory(f metrics.Factory) ProcessorFactory {
 // embeddingsProcessorRouterFilter implements [Processor] for the `/v1/embeddings` endpoint.
 //
 // This is primarily used to select the route for the request based on the model name.
-type embeddingsProcessorRouterFilter[T openai.EmbeddingRequest] struct {
+type embeddingsProcessorRouterFilter struct {
 	passThroughProcessor
 	// upstreamFilter is the upstream filter that is used to process the request at the upstream filter.
 	// This will be updated when the request is retried.
@@ -67,7 +67,7 @@ type embeddingsProcessorRouterFilter[T openai.EmbeddingRequest] struct {
 	// originalRequestBody is the original request body that is passed to the upstream filter.
 	// This is used to perform the transformation of the request body on the original input
 	// when the request is retried.
-	originalRequestBody    *T
+	originalRequestBody    *openai.EmbeddingRequest
 	originalRequestBodyRaw []byte
 	// tracer is the tracer used for requests.
 	tracer tracing.EmbeddingsTracer
@@ -79,7 +79,7 @@ type embeddingsProcessorRouterFilter[T openai.EmbeddingRequest] struct {
 }

 // ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
-func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseHeaders(ctx context.Context, headerMap *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
+func (e *embeddingsProcessorRouterFilter) ProcessResponseHeaders(ctx context.Context, headerMap *corev3.HeaderMap) (*extprocv3.ProcessingResponse, error) {
 	// If the request failed to route and/or immediate response was returned before the upstream filter was set,
 	// e.upstreamFilter can be nil.
 	if e.upstreamFilter != nil { // See the comment on the "upstreamFilter" field.
@@ -89,7 +89,7 @@ func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseHeaders(ctx context.
 }

 // ProcessResponseBody implements [Processor.ProcessResponseBody].
-func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
+func (e *embeddingsProcessorRouterFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
 	// If the request failed to route and/or immediate response was returned before the upstream filter was set,
 	// e.upstreamFilter can be nil.
 	if e.upstreamFilter != nil { // See the comment on the "upstreamFilter" field.
@@ -99,8 +99,8 @@ func (e *embeddingsProcessorRouterFilter[T]) ProcessResponseBody(ctx context.Con
 }

 // ProcessRequestBody implements [Processor.ProcessRequestBody].
-func (e *embeddingsProcessorRouterFilter[T]) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
-	originalModel, body, err := parseOpenAIEmbeddingBody[T](rawBody)
+func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (*extprocv3.ProcessingResponse, error) {
+	originalModel, body, err := parseOpenAIEmbeddingBody(rawBody)
 	if err != nil {
 		return nil, fmt.Errorf("failed to parse request body: %w", err)
 	}
@@ -125,7 +125,7 @@ func (e *embeddingsProcessorRouterFilter[T]) ProcessRequestBody(ctx context.Cont
 		ctx,
 		e.requestHeaders,
 		&headerMutationCarrier{m: headerMutation},
-		convertToEmbeddingCompletionRequest(body),
+		body,
 		rawBody.Body,
 	)

@@ -144,7 +144,7 @@ func (e *embeddingsProcessorRouterFilter[T]) ProcessRequestBody(ctx context.Cont
 // embeddingsProcessorUpstreamFilter implements [Processor] for the `/v1/embeddings` endpoint at the upstream filter.
 //
 // This is created per retry and handles the translation as well as the authentication of the request.
-type embeddingsProcessorUpstreamFilter[T openai.EmbeddingRequest] struct {
+type embeddingsProcessorUpstreamFilter struct {
 	logger         *slog.Logger
 	config         *filterapi.RuntimeConfig
 	requestHeaders map[string]string
@@ -156,7 +156,7 @@ type embeddingsProcessorUpstreamFilter[T openai.EmbeddingRequest] struct {
 	headerMutator          *headermutator.HeaderMutator
 	bodyMutator            *bodymutator.BodyMutator
 	originalRequestBodyRaw []byte
-	originalRequestBody    *T
+	originalRequestBody    *openai.EmbeddingRequest
 	translator             translator.OpenAIEmbeddingTranslator
 	// onRetry is true if this is a retry request at the upstream filter.
 	onRetry bool
@@ -169,14 +169,14 @@ type embeddingsProcessorUpstreamFilter[T openai.EmbeddingRequest] struct {
 }

 // selectTranslator selects the translator based on the output schema.
-func (e *embeddingsProcessorUpstreamFilter[T]) selectTranslator(out filterapi.VersionedAPISchema) error {
+func (e *embeddingsProcessorUpstreamFilter) selectTranslator(out filterapi.VersionedAPISchema) error {
 	switch out.Name {
 	case filterapi.APISchemaOpenAI:
 		e.translator = translator.NewEmbeddingOpenAIToOpenAITranslator(out.Version, e.modelNameOverride)
 	case filterapi.APISchemaAzureOpenAI:
 		e.translator = translator.NewEmbeddingOpenAIToAzureOpenAITranslator(out.Version, e.modelNameOverride)
 	case filterapi.APISchemaGCPVertexAI:
-		e.translator = translator.NewEmbeddingOpenAIToAzureOpenAITranslator(out.Version, e.modelNameOverride)
+		e.translator = translator.NewEmbeddingOpenAIToGCPVertexAITranslator("", e.modelNameOverride)
 	default:
 		return fmt.Errorf("unsupported API schema: backend=%s", out)
 	}
@@ -189,7 +189,7 @@ func (e *embeddingsProcessorUpstreamFilter[T]) selectTranslator(out filterapi.Ve
 // So, we simply do the translation and upstream auth at this stage, and send them back to Envoy
 // with the status CONTINUE_AND_REPLACE. This will allows Envoy to not send the request body again
 // to the extproc.
-func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
+func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
 	defer func() {
 		if err != nil {
 			e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
@@ -204,7 +204,7 @@ func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestHeaders(ctx context
 	reqModel := cmp.Or(e.requestHeaders[internalapi.ModelNameHeaderKeyDefault], openai.GetModelFromEmbeddingRequest(e.originalRequestBody))
 	e.metrics.SetRequestModel(reqModel)

-	newHeaders, newBody, err := e.translator.RequestBody(e.originalRequestBodyRaw, convertToEmbeddingCompletionRequest(e.originalRequestBody), e.onRetry)
+	newHeaders, newBody, err := e.translator.RequestBody(e.originalRequestBodyRaw, e.originalRequestBody, e.onRetry)
 	if err != nil {
 		return nil, fmt.Errorf("failed to transform request: %w", err)
 	}
@@ -267,12 +267,12 @@ func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestHeaders(ctx context
 }

 // ProcessRequestBody implements [Processor.ProcessRequestBody].
-func (e *embeddingsProcessorUpstreamFilter[T]) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
+func (e *embeddingsProcessorUpstreamFilter) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
 	panic("BUG: ProcessRequestBody should not be called in the upstream filter")
 }

 // ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
-func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseHeaders(ctx context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
+func (e *embeddingsProcessorUpstreamFilter) ProcessResponseHeaders(ctx context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
 	defer func() {
 		if err != nil {
 			e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
@@ -296,7 +296,7 @@ func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseHeaders(ctx contex
 }

 // ProcessResponseBody implements [Processor.ProcessResponseBody].
-func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
+func (e *embeddingsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
 	recordRequestCompletionErr := false
 	defer func() {
 		if err != nil || recordRequestCompletionErr {
@@ -385,13 +385,13 @@ func (e *embeddingsProcessorUpstreamFilter[T]) ProcessResponseBody(ctx context.C
 }

 // SetBackend implements [Processor.SetBackend].
-func (e *embeddingsProcessorUpstreamFilter[T]) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) {
+func (e *embeddingsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) {
 	defer func() {
 		if err != nil {
 			e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders)
 		}
 	}()
-	rp, ok := routeProcessor.(*embeddingsProcessorRouterFilter[T])
+	rp, ok := routeProcessor.(*embeddingsProcessorRouterFilter)
 	if !ok {
 		panic("BUG: expected routeProcessor to be of type *embeddingsProcessorRouterFilter")
 	}
@@ -420,27 +420,26 @@ func (e *embeddingsProcessorUpstreamFilter[T]) SetBackend(ctx context.Context, b
 }

 // convertToEmbeddingCompletionRequest converts any EmbeddingRequest to EmbeddingCompletionRequest for compatibility
-func convertToEmbeddingCompletionRequest[T openai.EmbeddingRequest](req *T) *openai.EmbeddingCompletionRequest {
-	switch r := any(*req).(type) {
-	case openai.EmbeddingCompletionRequest:
-		return &r
-	case openai.EmbeddingChatRequest:
+func convertToEmbeddingCompletionRequest(req *openai.EmbeddingRequest) *openai.EmbeddingCompletionRequest {
+	if req.OfCompletion != nil {
+		return req.OfCompletion
+	} else if req.OfChat != nil {
 		// Convert EmbeddingChatRequest to EmbeddingCompletionRequest by flattening messages to input
 		// This is a simplified conversion - in practice you might need more sophisticated logic
 		return &openai.EmbeddingCompletionRequest{
-			Model:          r.Model,
+			Model:          req.OfChat.Model,
 			Input:          openai.EmbeddingRequestInput{Value: "converted_from_chat"}, // Simplified
-			EncodingFormat: r.EncodingFormat,
-			Dimensions:     r.Dimensions,
-			User:           r.User,
+			EncodingFormat: req.OfChat.EncodingFormat,
+			Dimensions:     req.OfChat.Dimensions,
+			User:           req.OfChat.User,
 		}
-	default:
+	} else {
 		return &openai.EmbeddingCompletionRequest{}
 	}
 }

-func parseOpenAIEmbeddingBody[T openai.EmbeddingRequest](body *extprocv3.HttpBody) (modelName string, rb *T, err error) {
-	var openAIReq T
+func parseOpenAIEmbeddingBody(body *extprocv3.HttpBody) (modelName string, rb *openai.EmbeddingRequest, err error) {
+	var openAIReq openai.EmbeddingRequest
 	if err := json.Unmarshal(body.Body, &openAIReq); err != nil {
 		return "", nil, fmt.Errorf("failed to unmarshal body: %w", err)
 	}
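
The diff replaces the `[T openai.EmbeddingRequest]` type parameter with a plain union value that `parseOpenAIEmbeddingBody` unmarshals once and the filters pass around as `*openai.EmbeddingRequest`. The diff does not show how `openai.EmbeddingRequest` itself is defined, so the sketch below only illustrates the nil-checked variant dispatch that the new `convertToEmbeddingCompletionRequest` relies on; the type definitions, field sets, and the `UnmarshalJSON` discrimination strategy are assumptions for illustration, not the repository's actual code.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Hypothetical stand-ins for the openai package types referenced in the diff;
// the real definitions (field sets, custom UnmarshalJSON) live in the repository.
type EmbeddingCompletionRequest struct {
	Model string `json:"model"`
	Input any    `json:"input"`
}

type EmbeddingChatRequest struct {
	Model    string            `json:"model"`
	Messages []json.RawMessage `json:"messages"`
}

// EmbeddingRequest mirrors the union shape implied by the diff: after parsing,
// at most one of the Of* fields is non-nil.
type EmbeddingRequest struct {
	OfCompletion *EmbeddingCompletionRequest
	OfChat       *EmbeddingChatRequest
}

// UnmarshalJSON picks a variant by probing for a "messages" key. This is an
// assumed strategy; the real package may discriminate differently.
func (r *EmbeddingRequest) UnmarshalJSON(b []byte) error {
	var probe map[string]json.RawMessage
	if err := json.Unmarshal(b, &probe); err != nil {
		return err
	}
	if _, ok := probe["messages"]; ok {
		r.OfChat = &EmbeddingChatRequest{}
		return json.Unmarshal(b, r.OfChat)
	}
	r.OfCompletion = &EmbeddingCompletionRequest{}
	return json.Unmarshal(b, r.OfCompletion)
}

func main() {
	var req EmbeddingRequest
	if err := json.Unmarshal([]byte(`{"model":"text-embedding-3-small","input":"hello"}`), &req); err != nil {
		panic(err)
	}
	// Dispatch the way the new convertToEmbeddingCompletionRequest does:
	// nil-check the union fields instead of type-switching over a generic parameter.
	switch {
	case req.OfCompletion != nil:
		fmt.Println("completion variant, model:", req.OfCompletion.Model)
	case req.OfChat != nil:
		fmt.Println("chat variant, model:", req.OfChat.Model)
	default:
		fmt.Println("empty request")
	}
}
```

With the union, a single concrete `parseOpenAIEmbeddingBody` serves both request shapes, and the retry path can store and reuse the already-parsed `*openai.EmbeddingRequest` without per-type code paths.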