Skip to content

Commit 8d3310e

Browse files
authored
fix: move stop sequence decoding logic out of upstream filter (envoyproxy#1238)
**Description** Currently the `stop` field is defined as `any` type. When the `stop` field is passed as a string array, it is decoded as `[]any`, which then errors out in the `processStop` function in the upstream filter. The fix is to define the stop union type on the OpenAI-compatible API schema and decode it right there instead of in the translator. **Related Issues/PRs (if applicable)** Fixes envoyproxy#1237 --------- Signed-off-by: Dan Sun <[email protected]>
1 parent 40544e2 commit 8d3310e

File tree

13 files changed

+183
-186
lines changed

13 files changed

+183
-186
lines changed

internal/apischema/awsbedrock/awsbedrock.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ type InferenceConfiguration struct {
4646

4747
// A list of stop sequences. A stop sequence is a sequence of characters that
4848
// causes the model to stop generating the response.
49-
StopSequences []*string `json:"stopSequences,omitempty"`
49+
StopSequences []string `json:"stopSequences,omitempty"`
5050

5151
// The likelihood of the model selecting higher-probability options while generating
5252
// a response. A lower value makes the model more likely to choose higher-probability

internal/apischema/openai/openai.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ type ChatCompletionRequest struct {
874874
// Stop string / array / null Defaults to null
875875
// Up to 4 sequences where the API will stop generating further tokens.
876876
// Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop
877-
Stop any `json:"stop,omitempty"`
877+
Stop openai.ChatCompletionNewParamsStopUnion `json:"stop,omitzero"`
878878

879879
// Stream: If set, partial message deltas will be sent, like in ChatGPT.
880880
// Docs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream

internal/apischema/openai/openai_test.go

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import (
1111
"time"
1212

1313
"github.com/google/go-cmp/cmp"
14+
"github.com/google/go-cmp/cmp/cmpopts"
1415
"github.com/openai/openai-go/v2"
16+
"github.com/openai/openai-go/v2/packages/param"
1517
"github.com/stretchr/testify/require"
1618
"k8s.io/utils/ptr"
1719
)
@@ -381,10 +383,12 @@ func TestOpenAIChatCompletionMessageUnmarshal(t *testing.T) {
381383
},
382384
MaxCompletionTokens: ptr.To[int64](1024),
383385
ParallelToolCalls: ptr.To(true),
384-
Stop: []any{"\n", "stop"},
385-
ServiceTier: openai.ChatCompletionNewParamsServiceTierFlex,
386-
Verbosity: openai.ChatCompletionNewParamsVerbosityLow,
387-
ReasoningEffort: openai.ReasoningEffortLow,
386+
Stop: openai.ChatCompletionNewParamsStopUnion{
387+
OfStringArray: []string{"\n", "stop"},
388+
},
389+
ServiceTier: openai.ChatCompletionNewParamsServiceTierFlex,
390+
Verbosity: openai.ChatCompletionNewParamsVerbosityLow,
391+
ReasoningEffort: openai.ReasoningEffortLow,
388392
},
389393
},
390394
{
@@ -404,7 +408,31 @@ func TestOpenAIChatCompletionMessageUnmarshal(t *testing.T) {
404408
},
405409
},
406410
},
407-
Stop: "stop",
411+
Stop: openai.ChatCompletionNewParamsStopUnion{
412+
OfString: openai.Opt[string]("stop"),
413+
},
414+
},
415+
},
416+
{
417+
name: "stop as array",
418+
in: []byte(`{
419+
"model": "gpu-o4",
420+
"messages": [{"role": "user", "content": "hello"}],
421+
"stop": ["</s>", "__end_tag__", "<|eot_id|>", "[answer_end]"]
422+
}`),
423+
out: &ChatCompletionRequest{
424+
Model: "gpu-o4",
425+
Messages: []ChatCompletionMessageParamUnion{
426+
{
427+
OfUser: &ChatCompletionUserMessageParam{
428+
Role: ChatMessageRoleUser,
429+
Content: StringOrUserRoleContentUnion{Value: "hello"},
430+
},
431+
},
432+
},
433+
Stop: openai.ChatCompletionNewParamsStopUnion{
434+
OfStringArray: []string{"</s>", "__end_tag__", "<|eot_id|>", "[answer_end]"},
435+
},
408436
},
409437
},
410438
{
@@ -438,8 +466,10 @@ func TestOpenAIChatCompletionMessageUnmarshal(t *testing.T) {
438466
return
439467
}
440468
require.NoError(t, err)
441-
if !cmp.Equal(&chatCompletion, tc.out) {
442-
t.Errorf("UnmarshalOpenAIRequest(), diff(got, expected) = %s\n", cmp.Diff(&chatCompletion, tc.out))
469+
if !cmp.Equal(&chatCompletion, tc.out,
470+
cmpopts.IgnoreUnexported(openai.ChatCompletionNewParamsStopUnion{}, param.Opt[string]{})) {
471+
t.Errorf("UnmarshalOpenAIRequest(), diff(got, expected) = %s\n", cmp.Diff(&chatCompletion, tc.out,
472+
cmpopts.IgnoreUnexported(openai.ChatCompletionNewParamsStopUnion{}, param.Opt[string]{})))
443473
}
444474
})
445475
}

internal/apischema/openai/vendor_fields_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212
"github.com/anthropics/anthropic-sdk-go"
1313
"github.com/google/go-cmp/cmp"
1414
"github.com/google/go-cmp/cmp/cmpopts"
15+
"github.com/openai/openai-go/v2"
16+
"github.com/openai/openai-go/v2/packages/param"
1517
"github.com/stretchr/testify/require"
1618
"google.golang.org/genai"
1719
"k8s.io/utils/ptr"
@@ -233,7 +235,8 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) {
233235
}
234236

235237
require.NoError(t, err)
236-
if diff := cmp.Diff(tt.expected, &actual, cmpopts.IgnoreUnexported(anthropic.ThinkingConfigEnabledParam{}, anthropic.ThinkingConfigParamUnion{})); diff != "" {
238+
if diff := cmp.Diff(tt.expected, &actual, cmpopts.IgnoreUnexported(anthropic.ThinkingConfigEnabledParam{}, anthropic.ThinkingConfigParamUnion{},
239+
openai.ChatCompletionNewParamsStopUnion{}, param.Opt[string]{})); diff != "" {
237240
t.Errorf("ChatCompletionRequest mismatch (-expected +actual):\n%s", diff)
238241
}
239242
})

internal/extproc/translator/gemini_helper.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -442,18 +442,10 @@ func openAIReqToGeminiGenerationConfig(openAIReq *openai.ChatCompletionRequest)
442442
if openAIReq.FrequencyPenalty != nil {
443443
gc.FrequencyPenalty = openAIReq.FrequencyPenalty
444444
}
445-
stopSeq, err := processStop(openAIReq.Stop)
446-
if err != nil {
447-
return nil, err
448-
}
449-
if len(stopSeq) > 0 {
450-
var stops []string
451-
for _, s := range stopSeq {
452-
if s != nil {
453-
stops = append(stops, *s)
454-
}
455-
}
456-
gc.StopSequences = stops
445+
if openAIReq.Stop.OfString.Valid() {
446+
gc.StopSequences = []string{openAIReq.Stop.OfString.String()}
447+
} else if openAIReq.Stop.OfStringArray != nil {
448+
gc.StopSequences = openAIReq.Stop.OfStringArray
457449
}
458450
return gc, nil
459451
}

internal/extproc/translator/gemini_helper_test.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/google/go-cmp/cmp"
1313
"github.com/google/go-cmp/cmp/cmpopts"
14+
openaigo "github.com/openai/openai-go/v2"
1415
"github.com/stretchr/testify/assert"
1516
"github.com/stretchr/testify/require"
1617
"google.golang.org/genai"
@@ -737,7 +738,9 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) {
737738
MaxTokens: ptr.To(int64(256)),
738739
PresencePenalty: ptr.To(float32(1.1)),
739740
FrequencyPenalty: ptr.To(float32(0.5)),
740-
Stop: []*string{ptr.To("stop1"), ptr.To("stop2")},
741+
Stop: openaigo.ChatCompletionNewParamsStopUnion{
742+
OfStringArray: []string{"stop1", "stop2"},
743+
},
741744
},
742745
expectedGenerationConfig: &genai.GenerationConfig{
743746
Temperature: ptr.To(float32(0.7)),
@@ -757,6 +760,17 @@ func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) {
757760
input: &openai.ChatCompletionRequest{},
758761
expectedGenerationConfig: &genai.GenerationConfig{},
759762
},
763+
{
764+
name: "stop sequences",
765+
input: &openai.ChatCompletionRequest{
766+
Stop: openaigo.ChatCompletionNewParamsStopUnion{
767+
OfString: openaigo.Opt[string]("stop1"),
768+
},
769+
},
770+
expectedGenerationConfig: &genai.GenerationConfig{
771+
StopSequences: []string{"stop1"},
772+
},
773+
},
760774
{
761775
name: "text",
762776
input: &openai.ChatCompletionRequest{

internal/extproc/translator/openai_awsbedrock.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,10 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) RequestBody(_ []byte, ope
7878

7979
bedrockReq.InferenceConfig.MaxTokens = cmp.Or(openAIReq.MaxCompletionTokens, openAIReq.MaxTokens)
8080

81-
stopSequence, err := processStop(openAIReq.Stop)
82-
if err != nil {
83-
return
84-
}
85-
if len(stopSequence) > 0 {
86-
bedrockReq.InferenceConfig.StopSequences = stopSequence
81+
if openAIReq.Stop.OfString.Valid() {
82+
bedrockReq.InferenceConfig.StopSequences = []string{openAIReq.Stop.OfString.String()}
83+
} else if openAIReq.Stop.OfStringArray != nil {
84+
bedrockReq.InferenceConfig.StopSequences = openAIReq.Stop.OfStringArray
8785
}
8886

8987
// Handle Anthropic vendor fields if present. Currently only supports thinking fields.

internal/extproc/translator/openai_awsbedrock_test.go

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
2020
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
2121
"github.com/google/go-cmp/cmp"
22+
openaigo "github.com/openai/openai-go/v2"
2223
"github.com/stretchr/testify/require"
2324
"k8s.io/utils/ptr"
2425

@@ -712,11 +713,47 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T)
712713
},
713714
},
714715
},
715-
Stop: []*string{ptr.To("stop_only")},
716+
Stop: openaigo.ChatCompletionNewParamsStopUnion{
717+
OfString: openaigo.Opt[string]("stop_only"),
718+
},
719+
},
720+
output: awsbedrock.ConverseInput{
721+
InferenceConfig: &awsbedrock.InferenceConfiguration{
722+
StopSequences: []string{"stop_only"},
723+
},
724+
Messages: []*awsbedrock.Message{
725+
{
726+
Role: openai.ChatMessageRoleUser,
727+
Content: []*awsbedrock.ContentBlock{
728+
{
729+
Text: ptr.To("from-user"),
730+
},
731+
},
732+
},
733+
},
734+
},
735+
},
736+
{
737+
name: "test stop sequence",
738+
input: openai.ChatCompletionRequest{
739+
Model: "gpt-4o",
740+
Messages: []openai.ChatCompletionMessageParamUnion{
741+
{
742+
OfUser: &openai.ChatCompletionUserMessageParam{
743+
Content: openai.StringOrUserRoleContentUnion{
744+
Value: "from-user",
745+
},
746+
Role: openai.ChatMessageRoleUser,
747+
},
748+
},
749+
},
750+
Stop: openaigo.ChatCompletionNewParamsStopUnion{
751+
OfStringArray: []string{"stop1", "stop2"},
752+
},
716753
},
717754
output: awsbedrock.ConverseInput{
718755
InferenceConfig: &awsbedrock.InferenceConfiguration{
719-
StopSequences: []*string{ptr.To("stop_only")},
756+
StopSequences: []string{"stop1", "stop2"},
720757
},
721758
Messages: []*awsbedrock.Message{
722759
{

internal/extproc/translator/openai_gcpanthropic.go

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -515,20 +515,10 @@ func buildAnthropicParams(openAIReq *openai.ChatCompletionRequest) (params *anth
515515
if openAIReq.TopP != nil {
516516
params.TopP = anthropic.Float(*openAIReq.TopP)
517517
}
518-
519-
// Handle stop sequences.
520-
stopSequences, err := processStop(openAIReq.Stop)
521-
if err != nil {
522-
return nil, err
523-
}
524-
if len(stopSequences) > 0 {
525-
var stops []string
526-
for _, s := range stopSequences {
527-
if s != nil {
528-
stops = append(stops, *s)
529-
}
530-
}
531-
params.StopSequences = stops
518+
if openAIReq.Stop.OfString.Valid() {
519+
params.StopSequences = []string{openAIReq.Stop.OfString.String()}
520+
} else if openAIReq.Stop.OfStringArray != nil {
521+
params.StopSequences = openAIReq.Stop.OfStringArray
532522
}
533523

534524
// 5. Handle Vendor specific fields.

internal/extproc/translator/openai_gcpanthropic_test.go

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/anthropics/anthropic-sdk-go/shared/constant"
1919
anthropicVertex "github.com/anthropics/anthropic-sdk-go/vertex"
2020
"github.com/google/go-cmp/cmp"
21+
openaigo "github.com/openai/openai-go/v2"
2122
"github.com/stretchr/testify/require"
2223
"github.com/tidwall/gjson"
2324
"k8s.io/utils/ptr"
@@ -209,6 +210,44 @@ func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_RequestBody(t *testing.T
209210
require.Equal(t, expectedPath, string(pathHeader[0].Header.RawValue))
210211
})
211212

213+
t.Run("Test message param", func(t *testing.T) {
214+
openaiRequest := &openai.ChatCompletionRequest{
215+
Model: claudeTestModel,
216+
Messages: []openai.ChatCompletionMessageParamUnion{},
217+
Temperature: ptr.To(0.1),
218+
MaxTokens: ptr.To(int64(100)),
219+
TopP: ptr.To(0.1),
220+
Stop: openaigo.ChatCompletionNewParamsStopUnion{
221+
OfStringArray: []string{"stop1", "stop2"},
222+
},
223+
}
224+
messageParam, err := buildAnthropicParams(openaiRequest)
225+
require.NoError(t, err)
226+
require.Equal(t, int64(100), messageParam.MaxTokens)
227+
require.Equal(t, "0.1", messageParam.TopP.String())
228+
require.Equal(t, "0.1", messageParam.Temperature.String())
229+
require.Equal(t, []string{"stop1", "stop2"}, messageParam.StopSequences)
230+
})
231+
232+
t.Run("Test single stop", func(t *testing.T) {
233+
openaiRequest := &openai.ChatCompletionRequest{
234+
Model: claudeTestModel,
235+
Messages: []openai.ChatCompletionMessageParamUnion{},
236+
Temperature: ptr.To(0.1),
237+
MaxTokens: ptr.To(int64(100)),
238+
TopP: ptr.To(0.1),
239+
Stop: openaigo.ChatCompletionNewParamsStopUnion{
240+
OfString: openaigo.Opt[string]("stop1"),
241+
},
242+
}
243+
messageParam, err := buildAnthropicParams(openaiRequest)
244+
require.NoError(t, err)
245+
require.Equal(t, int64(100), messageParam.MaxTokens)
246+
require.Equal(t, "0.1", messageParam.TopP.String())
247+
require.Equal(t, "0.1", messageParam.Temperature.String())
248+
require.Equal(t, []string{"stop1"}, messageParam.StopSequences)
249+
})
250+
212251
t.Run("Invalid Temperature (above bound)", func(t *testing.T) {
213252
invalidTempReq := &openai.ChatCompletionRequest{
214253
Model: claudeTestModel,
@@ -847,12 +886,6 @@ func TestHelperFunctions(t *testing.T) {
847886
require.Error(t, err)
848887
require.Contains(t, err.Error(), "invalid anthropic role")
849888
})
850-
851-
t.Run("process stop with nil", func(t *testing.T) {
852-
val, err := processStop(nil)
853-
require.NoError(t, err)
854-
require.Nil(t, val)
855-
})
856889
}
857890

858891
func TestTranslateOpenAItoAnthropicTools(t *testing.T) {

0 commit comments

Comments
 (0)