Skip to content

Commit 6d7059c

Browse files
mathetakenacx
andauthored
refactor: consolidates all processors with one generic type (#1600)
**Description** This consolidates all the copy-pasted processors that existed per endpoint we support into one generic processor. This was made possible thanks to the series of refactoring that we landed in the past few weeks primarily for dynamic module work #90. Notably, now in order to add an endpoint, majority of the new code will be in translator (where also have shared generic interface) as well as the type definition. No longer it requires the huge copy paste of processors. **Related Issues/PRs (if applicable)** Resolves #1083 Blocker for #1429 #1584 #1592 #1594 --------- Signed-off-by: Takeshi Yoneda <[email protected]> Co-authored-by: Ignasi Barrera <[email protected]>
1 parent 6f312e5 commit 6d7059c

16 files changed

+905
-7224
lines changed

cmd/extproc/mainlib/main.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"google.golang.org/grpc"
2626
"google.golang.org/grpc/health/grpc_health_v1"
2727

28+
"github.com/envoyproxy/ai-gateway/internal/endpointspec"
2829
"github.com/envoyproxy/ai-gateway/internal/extproc"
2930
"github.com/envoyproxy/ai-gateway/internal/filterapi"
3031
"github.com/envoyproxy/ai-gateway/internal/internalapi"
@@ -254,13 +255,19 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) {
254255
if err != nil {
255256
return fmt.Errorf("failed to create external processor server: %w", err)
256257
}
257-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/chat/completions"), extproc.ChatCompletionProcessorFactory(chatCompletionMetricsFactory, tracing.ChatCompletionTracer()))
258-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/completions"), extproc.CompletionsProcessorFactory(completionMetricsFactory, tracing.CompletionTracer()))
259-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/embeddings"), extproc.EmbeddingsProcessorFactory(embeddingsMetricsFactory, tracing.EmbeddingsTracer()))
260-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/images/generations"), extproc.ImageGenerationProcessorFactory(imageGenerationMetricsFactory, tracing.ImageGenerationTracer()))
261-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Cohere, "/v2/rerank"), extproc.RerankProcessorFactory(rerankMetricsFactory, tracing.RerankTracer()))
258+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/chat/completions"), extproc.NewFactory(
259+
chatCompletionMetricsFactory, tracing.ChatCompletionTracer(), endpointspec.ChatCompletionsEndpointSpec{}))
260+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/completions"), extproc.NewFactory(
261+
completionMetricsFactory, tracing.CompletionTracer(), endpointspec.CompletionsEndpointSpec{}))
262+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/embeddings"), extproc.NewFactory(
263+
embeddingsMetricsFactory, tracing.EmbeddingsTracer(), endpointspec.EmbeddingsEndpointSpec{}))
264+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/images/generations"), extproc.NewFactory(
265+
imageGenerationMetricsFactory, tracing.ImageGenerationTracer(), endpointspec.ImageGenerationEndpointSpec{}))
266+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Cohere, "/v2/rerank"), extproc.NewFactory(
267+
rerankMetricsFactory, tracing.RerankTracer(), endpointspec.RerankEndpointSpec{}))
262268
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/models"), extproc.NewModelsProcessor)
263-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Anthropic, "/v1/messages"), extproc.MessagesProcessorFactory(messagesMetricsFactory, tracing.MessageTracer()))
269+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Anthropic, "/v1/messages"), extproc.NewFactory(
270+
messagesMetricsFactory, tracing.MessageTracer(), endpointspec.MessagesEndpointSpec{}))
264271

265272
if watchErr := filterapi.StartConfigWatcher(ctx, flags.configPath, server, l, time.Second*5); watchErr != nil {
266273
return fmt.Errorf("failed to start config watcher: %w", watchErr)
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
// Copyright Envoy AI Gateway Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
// The full text of the Apache license is available in the LICENSE file at
4+
// the root of the repo.
5+
6+
// Package endpointspec defines the EndpointSpec which is to bundle the translator, tracing
7+
// and most importantly request and response types for different API endpoints.
8+
package endpointspec
9+
10+
import (
11+
"encoding/json"
12+
"fmt"
13+
14+
openaisdk "github.com/openai/openai-go/v2"
15+
"github.com/tidwall/sjson"
16+
17+
"github.com/envoyproxy/ai-gateway/internal/apischema/anthropic"
18+
cohereschema "github.com/envoyproxy/ai-gateway/internal/apischema/cohere"
19+
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
20+
"github.com/envoyproxy/ai-gateway/internal/filterapi"
21+
"github.com/envoyproxy/ai-gateway/internal/internalapi"
22+
tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
23+
"github.com/envoyproxy/ai-gateway/internal/translator"
24+
)
25+
26+
type (
27+
// Spec defines methods for parsing request bodies and selecting translators
28+
// for different API endpoints.
29+
//
30+
// Type Parameters:
31+
// * ReqT: The request type.
32+
// * RespT: The response type.
33+
// * RespChunkT: The chunk type for streaming responses.
34+
//
35+
// This must be implemented by specific endpoint handlers to provide
36+
// custom logic for parsing and translation.
37+
Spec[ReqT, RespT, RespChunkT any] interface {
38+
// ParseBody parses the request body and returns the original model,
39+
// the parsed request, whether the request is streaming, any mutated body,
40+
// and an error if parsing fails.
41+
//
42+
// Parameters:
43+
// * body: The raw request body as a byte slice.
44+
// * costConfigured: A boolean indicating if cost metrics are configured.
45+
//
46+
// Returns:
47+
// * originalModel: The original model specified in the request.
48+
// * req: The parsed request of type ReqT.
49+
// * stream: A boolean indicating if the request is for streaming responses.
50+
// * mutatedBody: The possibly mutated request body as a byte slice. Or nil if no mutation is needed.
51+
// * err: An error if parsing fails.
52+
ParseBody(body []byte, costConfigured bool) (originalModel internalapi.OriginalModel, req *ReqT, stream bool, mutatedBody []byte, err error)
53+
// GetTranslator selects the appropriate translator based on the output API schema
54+
// and an optional model name override.
55+
//
56+
// Parameters:
57+
// * out: The output API schema for which the translator is needed.
58+
// * modelNameOverride: An optional model name to override the one specified in the request.
59+
//
60+
// Returns:
61+
// * translator: The selected translator of type Translator[ReqT, RespT, RespChunkT].
62+
// * err: An error if translator selection fails.
63+
GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.Translator[ReqT, tracing.Span[RespT, RespChunkT]], error)
64+
}
65+
// ChatCompletionsEndpointSpec implements EndpointSpec for /v1/chat/completions.
66+
ChatCompletionsEndpointSpec struct{}
67+
// CompletionsEndpointSpec implements EndpointSpec for /v1/completions.
68+
CompletionsEndpointSpec struct{}
69+
// EmbeddingsEndpointSpec implements EndpointSpec for /v1/embeddings.
70+
EmbeddingsEndpointSpec struct{}
71+
// ImageGenerationEndpointSpec implements EndpointSpec for /v1/images/generations.
72+
ImageGenerationEndpointSpec struct{}
73+
// MessagesEndpointSpec implements EndpointSpec for /v1/messages.
74+
MessagesEndpointSpec struct{}
75+
// RerankEndpointSpec implements EndpointSpec for /v2/rerank.
76+
RerankEndpointSpec struct{}
77+
)
78+
79+
// ParseBody implements [EndpointSpec.ParseBody].
80+
func (ChatCompletionsEndpointSpec) ParseBody(
81+
body []byte,
82+
costConfigured bool,
83+
) (internalapi.OriginalModel, *openai.ChatCompletionRequest, bool, []byte, error) {
84+
var req openai.ChatCompletionRequest
85+
if err := json.Unmarshal(body, &req); err != nil {
86+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal chat completion request: %w", err)
87+
}
88+
var mutatedBody []byte
89+
if req.Stream && costConfigured && (req.StreamOptions == nil || !req.StreamOptions.IncludeUsage) {
90+
// If the request is a streaming request and cost metrics are configured, we need to include usage in the response
91+
// to avoid the bypassing of the token usage calculation.
92+
req.StreamOptions = &openai.StreamOptions{IncludeUsage: true}
93+
// Rewrite the original bytes to include the stream_options.include_usage=true so that forcing the request body
94+
// mutation, which uses this raw body, will also result in the stream_options.include_usage=true.
95+
var err error
96+
mutatedBody, err = sjson.SetBytesOptions(body, "stream_options.include_usage", true, &sjson.Options{
97+
Optimistic: true,
98+
// Note: it is safe to do in-place replacement since this route level processor is executed once per request,
99+
// and the result can be safely shared among possible multiple retries.
100+
ReplaceInPlace: true,
101+
})
102+
if err != nil {
103+
return "", nil, false, nil, fmt.Errorf("failed to set stream_options: %w", err)
104+
}
105+
}
106+
return req.Model, &req, req.Stream, mutatedBody, nil
107+
}
108+
109+
// GetTranslator implements [EndpointSpec.GetTranslator].
110+
func (ChatCompletionsEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.OpenAIChatCompletionTranslator, error) {
111+
switch schema.Name {
112+
case filterapi.APISchemaOpenAI:
113+
return translator.NewChatCompletionOpenAIToOpenAITranslator(schema.Version, modelNameOverride), nil
114+
case filterapi.APISchemaAWSBedrock:
115+
return translator.NewChatCompletionOpenAIToAWSBedrockTranslator(modelNameOverride), nil
116+
case filterapi.APISchemaAzureOpenAI:
117+
return translator.NewChatCompletionOpenAIToAzureOpenAITranslator(schema.Version, modelNameOverride), nil
118+
case filterapi.APISchemaGCPVertexAI:
119+
return translator.NewChatCompletionOpenAIToGCPVertexAITranslator(modelNameOverride), nil
120+
case filterapi.APISchemaGCPAnthropic:
121+
return translator.NewChatCompletionOpenAIToGCPAnthropicTranslator(schema.Version, modelNameOverride), nil
122+
default:
123+
return nil, fmt.Errorf("unsupported API schema: backend=%s", schema)
124+
}
125+
}
126+
127+
// ParseBody implements [EndpointSpec.ParseBody].
128+
func (CompletionsEndpointSpec) ParseBody(
129+
body []byte,
130+
_ bool,
131+
) (internalapi.OriginalModel, *openai.CompletionRequest, bool, []byte, error) {
132+
var openAIReq openai.CompletionRequest
133+
if err := json.Unmarshal(body, &openAIReq); err != nil {
134+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal completion request: %w", err)
135+
}
136+
return openAIReq.Model, &openAIReq, openAIReq.Stream, nil, nil
137+
}
138+
139+
// GetTranslator implements [EndpointSpec.GetTranslator].
140+
func (CompletionsEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.OpenAICompletionTranslator, error) {
141+
switch schema.Name {
142+
case filterapi.APISchemaOpenAI:
143+
return translator.NewCompletionOpenAIToOpenAITranslator(schema.Version, modelNameOverride), nil
144+
default:
145+
return nil, fmt.Errorf("unsupported API schema: backend=%s", schema)
146+
}
147+
}
148+
149+
// ParseBody implements [EndpointSpec.ParseBody].
150+
func (EmbeddingsEndpointSpec) ParseBody(
151+
body []byte,
152+
_ bool,
153+
) (internalapi.OriginalModel, *openai.EmbeddingRequest, bool, []byte, error) {
154+
var openAIReq openai.EmbeddingRequest
155+
if err := json.Unmarshal(body, &openAIReq); err != nil {
156+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal embedding request: %w", err)
157+
}
158+
return openAIReq.Model, &openAIReq, false, nil, nil
159+
}
160+
161+
// GetTranslator implements [EndpointSpec.GetTranslator].
162+
func (EmbeddingsEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.OpenAIEmbeddingTranslator, error) {
163+
switch schema.Name {
164+
case filterapi.APISchemaOpenAI:
165+
return translator.NewEmbeddingOpenAIToOpenAITranslator(schema.Version, modelNameOverride), nil
166+
case filterapi.APISchemaAzureOpenAI:
167+
return translator.NewEmbeddingOpenAIToAzureOpenAITranslator(schema.Version, modelNameOverride), nil
168+
default:
169+
return nil, fmt.Errorf("unsupported API schema: backend=%s", schema)
170+
}
171+
}
172+
173+
func (ImageGenerationEndpointSpec) ParseBody(
174+
body []byte,
175+
_ bool,
176+
) (internalapi.OriginalModel, *openaisdk.ImageGenerateParams, bool, []byte, error) {
177+
var openAIReq openaisdk.ImageGenerateParams
178+
if err := json.Unmarshal(body, &openAIReq); err != nil {
179+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal image generation request: %w", err)
180+
}
181+
return openAIReq.Model, &openAIReq, false, nil, nil
182+
}
183+
184+
// GetTranslator implements [EndpointSpec.GetTranslator].
185+
func (ImageGenerationEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.OpenAIImageGenerationTranslator, error) {
186+
switch schema.Name {
187+
case filterapi.APISchemaOpenAI:
188+
return translator.NewImageGenerationOpenAIToOpenAITranslator(schema.Version, modelNameOverride), nil
189+
default:
190+
return nil, fmt.Errorf("unsupported API schema: backend=%s", schema)
191+
}
192+
}
193+
194+
// ParseBody implements [EndpointSpec.ParseBody].
195+
func (MessagesEndpointSpec) ParseBody(
196+
body []byte,
197+
_ bool,
198+
) (internalapi.OriginalModel, *anthropic.MessagesRequest, bool, []byte, error) {
199+
var anthropicReq anthropic.MessagesRequest
200+
if err := json.Unmarshal(body, &anthropicReq); err != nil {
201+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal Anthropic Messages body: %w", err)
202+
}
203+
204+
model := anthropicReq.GetModel()
205+
if model == "" {
206+
return "", nil, false, nil, fmt.Errorf("model field is required in Anthropic request")
207+
}
208+
209+
stream := anthropicReq.GetStream()
210+
return model, &anthropicReq, stream, nil, nil
211+
}
212+
213+
// GetTranslator implements [EndpointSpec.GetTranslator].
214+
func (MessagesEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.AnthropicMessagesTranslator, error) {
215+
// Messages processor only supports Anthropic-native translators.
216+
switch schema.Name {
217+
case filterapi.APISchemaGCPAnthropic:
218+
return translator.NewAnthropicToGCPAnthropicTranslator(schema.Version, modelNameOverride), nil
219+
case filterapi.APISchemaAWSAnthropic:
220+
return translator.NewAnthropicToAWSAnthropicTranslator(schema.Version, modelNameOverride), nil
221+
case filterapi.APISchemaAnthropic:
222+
return translator.NewAnthropicToAnthropicTranslator(schema.Version, modelNameOverride), nil
223+
default:
224+
return nil, fmt.Errorf("/v1/messages endpoint only supports backends that return native Anthropic format (Anthropic, GCPAnthropic, AWSAnthropic). Backend %s uses different model format", schema.Name)
225+
}
226+
}
227+
228+
// ParseBody implements [EndpointSpec.ParseBody].
229+
func (RerankEndpointSpec) ParseBody(
230+
body []byte,
231+
_ bool,
232+
) (internalapi.OriginalModel, *cohereschema.RerankV2Request, bool, []byte, error) {
233+
var req cohereschema.RerankV2Request
234+
if err := json.Unmarshal(body, &req); err != nil {
235+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal rerank request: %w", err)
236+
}
237+
return req.Model, &req, false, nil, nil
238+
}
239+
240+
// GetTranslator implements [EndpointSpec.GetTranslator].
241+
func (RerankEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.CohereRerankTranslator, error) {
242+
switch schema.Name {
243+
case filterapi.APISchemaCohere:
244+
return translator.NewRerankCohereToCohereTranslator(schema.Version, modelNameOverride), nil
245+
default:
246+
return nil, fmt.Errorf("unsupported API schema: backend=%s", schema)
247+
}
248+
}

0 commit comments

Comments
 (0)