Skip to content

Commit a0e8f9d

Browse files
authored
Merge branch 'main' into gcp-anthropic
2 parents 1249077 + 1b06346 commit a0e8f9d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1936
-7562
lines changed

api/v1alpha1/mcp_route.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,13 @@ type ProtectedResourceMetadata struct {
262262
// +optional
263263
ResourceName *string `json:"resourceName,omitempty"`
264264

265-
// ScopesSupported is a list of OAuth 2.0 scopes that the resource server supports.
265+
// ScopesSupported defines the minimal set of scopes required for the basic functionality of the MCPRoute.
266+
// It should avoid broad or overly permissive scopes to prevent clients from requesting tokens with excessive privileges.
267+
//
268+
// If an operation requires additional scopes that are not present in the access token, the client will receive a
269+
// 403 Forbidden response that includes the required scopes in the `scope` field of the `WWW-Authenticate` header.
270+
// This enables incremental privilege elevation through targeted `WWW-Authenticate: scope="..."` challenges when
271+
// privileged operations are first attempted.
266272
//
267273
// +kubebuilder:validation:Optional
268274
// +kubebuilder:validation:MaxItems=32

cmd/aigw/run.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,15 @@ func run(ctx context.Context, c cmdRun, o *runOpts, stdout, stderr io.Writer) er
138138
// Do the translation of the given AI Gateway resources Yaml into Envoy Gateway resources and write them to the file.
139139
resourcesBuf := &bytes.Buffer{}
140140
runCtx := &runCmdContext{
141-
isDebug: c.Debug,
142-
envoyGatewayResourcesOut: resourcesBuf,
143-
stderrLogger: debugLogger,
144-
stderr: stderr,
145-
tmpdir: filepath.Dir(o.logPath), // runDir
146-
udsPath: o.extprocUDSPath,
147-
adminPort: c.AdminPort,
148-
extProcLauncher: o.extProcLauncher,
141+
isDebug: c.Debug,
142+
envoyGatewayResourcesOut: resourcesBuf,
143+
stderrLogger: debugLogger,
144+
stderr: stderr,
145+
tmpdir: filepath.Dir(o.logPath), // runDir
146+
udsPath: o.extprocUDSPath,
147+
adminPort: c.AdminPort,
148+
extProcLauncher: o.extProcLauncher,
149+
mcpSessionEncryptionIterations: c.MCPSessionEncryptionIterations,
149150
}
150151
// If any of the configured MCP servers is using stdio, set up the streamable HTTP proxies for them
151152
if err = proxyStdioMCPServers(ctx, debugLogger, c.mcpConfig); err != nil {

cmd/extproc/mainlib/main.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"google.golang.org/grpc"
2626
"google.golang.org/grpc/health/grpc_health_v1"
2727

28+
"github.com/envoyproxy/ai-gateway/internal/endpointspec"
2829
"github.com/envoyproxy/ai-gateway/internal/extproc"
2930
"github.com/envoyproxy/ai-gateway/internal/filterapi"
3031
"github.com/envoyproxy/ai-gateway/internal/internalapi"
@@ -254,13 +255,19 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) {
254255
if err != nil {
255256
return fmt.Errorf("failed to create external processor server: %w", err)
256257
}
257-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/chat/completions"), extproc.ChatCompletionProcessorFactory(chatCompletionMetricsFactory, tracing.ChatCompletionTracer()))
258-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/completions"), extproc.CompletionsProcessorFactory(completionMetricsFactory, tracing.CompletionTracer()))
259-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/embeddings"), extproc.EmbeddingsProcessorFactory(embeddingsMetricsFactory, tracing.EmbeddingsTracer()))
260-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/images/generations"), extproc.ImageGenerationProcessorFactory(imageGenerationMetricsFactory, tracing.ImageGenerationTracer()))
261-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Cohere, "/v2/rerank"), extproc.RerankProcessorFactory(rerankMetricsFactory, tracing.RerankTracer()))
258+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/chat/completions"), extproc.NewFactory(
259+
chatCompletionMetricsFactory, tracing.ChatCompletionTracer(), endpointspec.ChatCompletionsEndpointSpec{}))
260+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/completions"), extproc.NewFactory(
261+
completionMetricsFactory, tracing.CompletionTracer(), endpointspec.CompletionsEndpointSpec{}))
262+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/embeddings"), extproc.NewFactory(
263+
embeddingsMetricsFactory, tracing.EmbeddingsTracer(), endpointspec.EmbeddingsEndpointSpec{}))
264+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/images/generations"), extproc.NewFactory(
265+
imageGenerationMetricsFactory, tracing.ImageGenerationTracer(), endpointspec.ImageGenerationEndpointSpec{}))
266+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Cohere, "/v2/rerank"), extproc.NewFactory(
267+
rerankMetricsFactory, tracing.RerankTracer(), endpointspec.RerankEndpointSpec{}))
262268
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.OpenAI, "/v1/models"), extproc.NewModelsProcessor)
263-
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Anthropic, "/v1/messages"), extproc.MessagesProcessorFactory(messagesMetricsFactory, tracing.MessageTracer()))
269+
server.Register(path.Join(flags.rootPrefix, endpointPrefixes.Anthropic, "/v1/messages"), extproc.NewFactory(
270+
messagesMetricsFactory, tracing.MessageTracer(), endpointspec.MessagesEndpointSpec{}))
264271

265272
if watchErr := filterapi.StartConfigWatcher(ctx, flags.configPath, server, l, time.Second*5); watchErr != nil {
266273
return fmt.Errorf("failed to start config watcher: %w", watchErr)

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ require (
2828
github.com/google/go-cmp v0.7.0
2929
github.com/google/jsonschema-go v0.3.0
3030
github.com/google/uuid v1.6.0
31-
github.com/modelcontextprotocol/go-sdk v1.0.0
31+
github.com/modelcontextprotocol/go-sdk v1.1.0
3232
github.com/openai/openai-go v1.12.0
3333
github.com/openai/openai-go/v2 v2.7.1
3434
github.com/prometheus/client_golang v1.23.2

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,8 @@ github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g
331331
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
332332
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
333333
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
334-
github.com/modelcontextprotocol/go-sdk v1.0.0 h1:Z4MSjLi38bTgLrd/LjSmofqRqyBiVKRyQSJgw8q8V74=
335-
github.com/modelcontextprotocol/go-sdk v1.0.0/go.mod h1:nYtYQroQ2KQiM0/SbyEPUWQ6xs4B95gJjEalc9AQyOs=
334+
github.com/modelcontextprotocol/go-sdk v1.1.0 h1:Qjayg53dnKC4UZ+792W21e4BpwEZBzwgRW6LrjLWSwA=
335+
github.com/modelcontextprotocol/go-sdk v1.1.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
336336
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
337337
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
338338
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

internal/controller/mcp_route_security_policy.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,11 @@ func buildWWWAuthenticateHeaderValue(metadata *aigv1a1.ProtectedResourceMetadata
323323
// Add resource_metadata as per RFC 9728 Section 5.1.
324324
headerValue = fmt.Sprintf(`%s, resource_metadata="%s"`, headerValue, resourceMetadataURL)
325325

326+
if len(metadata.ScopesSupported) > 0 {
327+
// Add scope as per RFC 6750 Section 3.
328+
headerValue = fmt.Sprintf(`%s, scope="%s"`, headerValue, strings.Join(metadata.ScopesSupported, " "))
329+
}
330+
326331
return headerValue
327332
}
328333

internal/controller/mcp_route_security_policy_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,30 @@ func Test_buildWWWAuthenticateHeaderValue(t *testing.T) {
654654
},
655655
expected: `Bearer error="invalid_request", error_description="No access token was provided in this request", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource/v1/mcp/endpoint"`,
656656
},
657+
{
658+
name: "with empty scopes supported",
659+
metadata: &aigv1a1.ProtectedResourceMetadata{
660+
Resource: "https://api.example.com/mcp",
661+
ScopesSupported: []string{},
662+
},
663+
expected: `Bearer error="invalid_request", error_description="No access token was provided in this request", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource/mcp"`,
664+
},
665+
{
666+
name: "with single scope supported",
667+
metadata: &aigv1a1.ProtectedResourceMetadata{
668+
Resource: "https://api.example.com/mcp",
669+
ScopesSupported: []string{"read"},
670+
},
671+
expected: `Bearer error="invalid_request", error_description="No access token was provided in this request", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource/mcp", scope="read"`,
672+
},
673+
{
674+
name: "with multiple scopes supported",
675+
metadata: &aigv1a1.ProtectedResourceMetadata{
676+
Resource: "https://api.example.com/mcp",
677+
ScopesSupported: []string{"read", "write"},
678+
},
679+
expected: `Bearer error="invalid_request", error_description="No access token was provided in this request", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource/mcp", scope="read write"`,
680+
},
657681
}
658682

659683
for _, tt := range tests {
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
// Copyright Envoy AI Gateway Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
// The full text of the Apache license is available in the LICENSE file at
4+
// the root of the repo.
5+
6+
// Package endpointspec defines the EndpointSpec which is to bundle the translator, tracing
7+
// and most importantly request and response types for different API endpoints.
8+
package endpointspec
9+
10+
import (
11+
"encoding/json"
12+
"fmt"
13+
14+
openaisdk "github.com/openai/openai-go/v2"
15+
"github.com/tidwall/sjson"
16+
17+
"github.com/envoyproxy/ai-gateway/internal/apischema/anthropic"
18+
cohereschema "github.com/envoyproxy/ai-gateway/internal/apischema/cohere"
19+
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
20+
"github.com/envoyproxy/ai-gateway/internal/filterapi"
21+
"github.com/envoyproxy/ai-gateway/internal/internalapi"
22+
tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api"
23+
"github.com/envoyproxy/ai-gateway/internal/translator"
24+
)
25+
26+
type (
27+
// Spec defines methods for parsing request bodies and selecting translators
28+
// for different API endpoints.
29+
//
30+
// Type Parameters:
31+
// * ReqT: The request type.
32+
// * RespT: The response type.
33+
// * RespChunkT: The chunk type for streaming responses.
34+
//
35+
// This must be implemented by specific endpoint handlers to provide
36+
// custom logic for parsing and translation.
37+
Spec[ReqT, RespT, RespChunkT any] interface {
38+
// ParseBody parses the request body and returns the original model,
39+
// the parsed request, whether the request is streaming, any mutated body,
40+
// and an error if parsing fails.
41+
//
42+
// Parameters:
43+
// * body: The raw request body as a byte slice.
44+
// * costConfigured: A boolean indicating if cost metrics are configured.
45+
//
46+
// Returns:
47+
// * originalModel: The original model specified in the request.
48+
// * req: The parsed request of type ReqT.
49+
// * stream: A boolean indicating if the request is for streaming responses.
50+
// * mutatedBody: The possibly mutated request body as a byte slice. Or nil if no mutation is needed.
51+
// * err: An error if parsing fails.
52+
ParseBody(body []byte, costConfigured bool) (originalModel internalapi.OriginalModel, req *ReqT, stream bool, mutatedBody []byte, err error)
53+
// GetTranslator selects the appropriate translator based on the output API schema
54+
// and an optional model name override.
55+
//
56+
// Parameters:
57+
// * schema: The output API schema for which the translator is needed.
58+
// * modelNameOverride: An optional model name to override the one specified in the request.
59+
//
60+
// Returns:
61+
// * translator: The selected translator of type translator.Translator[ReqT, tracing.Span[RespT, RespChunkT]].
62+
// * err: An error if translator selection fails.
63+
GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.Translator[ReqT, tracing.Span[RespT, RespChunkT]], error)
64+
}
65+
// ChatCompletionsEndpointSpec implements Spec for /v1/chat/completions.
66+
ChatCompletionsEndpointSpec struct{}
67+
// CompletionsEndpointSpec implements Spec for /v1/completions.
68+
CompletionsEndpointSpec struct{}
69+
// EmbeddingsEndpointSpec implements Spec for /v1/embeddings.
70+
EmbeddingsEndpointSpec struct{}
71+
// ImageGenerationEndpointSpec implements Spec for /v1/images/generations.
72+
ImageGenerationEndpointSpec struct{}
73+
// MessagesEndpointSpec implements Spec for /v1/messages.
74+
MessagesEndpointSpec struct{}
75+
// RerankEndpointSpec implements Spec for /v2/rerank.
76+
RerankEndpointSpec struct{}
77+
)
78+
79+
// ParseBody implements [Spec.ParseBody].
80+
func (ChatCompletionsEndpointSpec) ParseBody(
81+
body []byte,
82+
costConfigured bool,
83+
) (internalapi.OriginalModel, *openai.ChatCompletionRequest, bool, []byte, error) {
84+
var req openai.ChatCompletionRequest
85+
if err := json.Unmarshal(body, &req); err != nil {
86+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal chat completion request: %w", err)
87+
}
88+
var mutatedBody []byte
89+
if req.Stream && costConfigured && (req.StreamOptions == nil || !req.StreamOptions.IncludeUsage) {
90+
// If the request is a streaming request and cost metrics are configured, we need to include usage in the response
91+
// so that the token usage calculation is not bypassed.
92+
req.StreamOptions = &openai.StreamOptions{IncludeUsage: true}
93+
// Also rewrite the raw bytes to set stream_options.include_usage=true, so that a forced request body
94+
// mutation, which starts from this raw body, carries stream_options.include_usage=true as well.
95+
var err error
96+
mutatedBody, err = sjson.SetBytesOptions(body, "stream_options.include_usage", true, &sjson.Options{
97+
Optimistic: true,
98+
// Note: it is safe to do in-place replacement since this route level processor is executed once per request,
99+
// and the result can be safely shared among possible multiple retries.
100+
ReplaceInPlace: true,
101+
})
102+
if err != nil {
103+
return "", nil, false, nil, fmt.Errorf("failed to set stream_options: %w", err)
104+
}
105+
}
106+
return req.Model, &req, req.Stream, mutatedBody, nil
107+
}
108+
109+
// GetTranslator implements [Spec.GetTranslator].
110+
func (ChatCompletionsEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.OpenAIChatCompletionTranslator, error) {
111+
switch schema.Name {
112+
case filterapi.APISchemaOpenAI:
113+
return translator.NewChatCompletionOpenAIToOpenAITranslator(schema.Version, modelNameOverride), nil
114+
case filterapi.APISchemaAWSBedrock:
115+
return translator.NewChatCompletionOpenAIToAWSBedrockTranslator(modelNameOverride), nil
116+
case filterapi.APISchemaAzureOpenAI:
117+
return translator.NewChatCompletionOpenAIToAzureOpenAITranslator(schema.Version, modelNameOverride), nil
118+
case filterapi.APISchemaGCPVertexAI:
119+
return translator.NewChatCompletionOpenAIToGCPVertexAITranslator(modelNameOverride), nil
120+
case filterapi.APISchemaGCPAnthropic:
121+
return translator.NewChatCompletionOpenAIToGCPAnthropicTranslator(schema.Version, modelNameOverride), nil
122+
default:
123+
return nil, fmt.Errorf("unsupported API schema: backend=%s", schema)
124+
}
125+
}
126+
127+
// ParseBody implements [Spec.ParseBody].
128+
func (CompletionsEndpointSpec) ParseBody(
129+
body []byte,
130+
_ bool,
131+
) (internalapi.OriginalModel, *openai.CompletionRequest, bool, []byte, error) {
132+
var openAIReq openai.CompletionRequest
133+
if err := json.Unmarshal(body, &openAIReq); err != nil {
134+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal completion request: %w", err)
135+
}
136+
return openAIReq.Model, &openAIReq, openAIReq.Stream, nil, nil
137+
}
138+
139+
// GetTranslator implements [Spec.GetTranslator].
140+
func (CompletionsEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.OpenAICompletionTranslator, error) {
141+
switch schema.Name {
142+
case filterapi.APISchemaOpenAI:
143+
return translator.NewCompletionOpenAIToOpenAITranslator(schema.Version, modelNameOverride), nil
144+
default:
145+
return nil, fmt.Errorf("unsupported API schema: backend=%s", schema)
146+
}
147+
}
148+
149+
// ParseBody implements [Spec.ParseBody].
150+
func (EmbeddingsEndpointSpec) ParseBody(
151+
body []byte,
152+
_ bool,
153+
) (internalapi.OriginalModel, *openai.EmbeddingRequest, bool, []byte, error) {
154+
var openAIReq openai.EmbeddingRequest
155+
if err := json.Unmarshal(body, &openAIReq); err != nil {
156+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal embedding request: %w", err)
157+
}
158+
return openAIReq.Model, &openAIReq, false, nil, nil
159+
}
160+
161+
// GetTranslator implements [Spec.GetTranslator].
162+
func (EmbeddingsEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.OpenAIEmbeddingTranslator, error) {
163+
switch schema.Name {
164+
case filterapi.APISchemaOpenAI:
165+
return translator.NewEmbeddingOpenAIToOpenAITranslator(schema.Version, modelNameOverride), nil
166+
case filterapi.APISchemaAzureOpenAI:
167+
return translator.NewEmbeddingOpenAIToAzureOpenAITranslator(schema.Version, modelNameOverride), nil
168+
default:
169+
return nil, fmt.Errorf("unsupported API schema: backend=%s", schema)
170+
}
171+
}
172+
173+
func (ImageGenerationEndpointSpec) ParseBody(
174+
body []byte,
175+
_ bool,
176+
) (internalapi.OriginalModel, *openaisdk.ImageGenerateParams, bool, []byte, error) {
177+
var openAIReq openaisdk.ImageGenerateParams
178+
if err := json.Unmarshal(body, &openAIReq); err != nil {
179+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal image generation request: %w", err)
180+
}
181+
return openAIReq.Model, &openAIReq, false, nil, nil
182+
}
183+
184+
// GetTranslator implements [Spec.GetTranslator].
185+
func (ImageGenerationEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.OpenAIImageGenerationTranslator, error) {
186+
switch schema.Name {
187+
case filterapi.APISchemaOpenAI:
188+
return translator.NewImageGenerationOpenAIToOpenAITranslator(schema.Version, modelNameOverride), nil
189+
default:
190+
return nil, fmt.Errorf("unsupported API schema: backend=%s", schema)
191+
}
192+
}
193+
194+
// ParseBody implements [Spec.ParseBody].
195+
func (MessagesEndpointSpec) ParseBody(
196+
body []byte,
197+
_ bool,
198+
) (internalapi.OriginalModel, *anthropic.MessagesRequest, bool, []byte, error) {
199+
var anthropicReq anthropic.MessagesRequest
200+
if err := json.Unmarshal(body, &anthropicReq); err != nil {
201+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal Anthropic Messages body: %w", err)
202+
}
203+
204+
model := anthropicReq.GetModel()
205+
if model == "" {
206+
return "", nil, false, nil, fmt.Errorf("model field is required in Anthropic request")
207+
}
208+
209+
stream := anthropicReq.GetStream()
210+
return model, &anthropicReq, stream, nil, nil
211+
}
212+
213+
// GetTranslator implements [Spec.GetTranslator].
214+
func (MessagesEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.AnthropicMessagesTranslator, error) {
215+
// Messages processor only supports Anthropic-native translators.
216+
switch schema.Name {
217+
case filterapi.APISchemaGCPAnthropic:
218+
return translator.NewAnthropicToGCPAnthropicTranslator(schema.Version, modelNameOverride), nil
219+
case filterapi.APISchemaAWSAnthropic:
220+
return translator.NewAnthropicToAWSAnthropicTranslator(schema.Version, modelNameOverride), nil
221+
case filterapi.APISchemaAnthropic:
222+
return translator.NewAnthropicToAnthropicTranslator(schema.Version, modelNameOverride), nil
223+
default:
224+
return nil, fmt.Errorf("/v1/messages endpoint only supports backends that return native Anthropic format (Anthropic, GCPAnthropic, AWSAnthropic). Backend %s uses different model format", schema.Name)
225+
}
226+
}
227+
228+
// ParseBody implements [Spec.ParseBody].
229+
func (RerankEndpointSpec) ParseBody(
230+
body []byte,
231+
_ bool,
232+
) (internalapi.OriginalModel, *cohereschema.RerankV2Request, bool, []byte, error) {
233+
var req cohereschema.RerankV2Request
234+
if err := json.Unmarshal(body, &req); err != nil {
235+
return "", nil, false, nil, fmt.Errorf("failed to unmarshal rerank request: %w", err)
236+
}
237+
return req.Model, &req, false, nil, nil
238+
}
239+
240+
// GetTranslator implements [Spec.GetTranslator].
241+
func (RerankEndpointSpec) GetTranslator(schema filterapi.VersionedAPISchema, modelNameOverride string) (translator.CohereRerankTranslator, error) {
242+
switch schema.Name {
243+
case filterapi.APISchemaCohere:
244+
return translator.NewRerankCohereToCohereTranslator(schema.Version, modelNameOverride), nil
245+
default:
246+
return nil, fmt.Errorf("unsupported API schema: backend=%s", schema)
247+
}
248+
}

0 commit comments

Comments
 (0)