Skip to content

Commit 96a2fdd

Browse files
hustxiayangyuzisun
andauthored
feat: provide unified thinking config among anthropic and gemini models (#1461)
**Description** `thinking_config` among anthropic and gemini models are similar, for example: `thinking` and `thinking_config`; `budget_tokens` and `thinkingBudget`. Our users do not need to set up different thinking configs for different models as we can provide a unified interface among different providers. **Related Issues/PRs (if applicable)** Related to #1463 --------- Signed-off-by: yxia216 <[email protected]> Co-authored-by: Dan Sun <[email protected]>
1 parent 9c7af75 commit 96a2fdd

File tree

9 files changed

+181
-169
lines changed

9 files changed

+181
-169
lines changed

docs/proposals/004-vendor-specific-fields/proposal.md

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,20 @@ type ChatCompletionRequest struct {
3636

3737
// Vendor-specific fields are added as inline fields
3838
*GCPVertexAIVendorFields `json:",inline,omitempty"`
39-
*AnthropicVendorFields `json:",inline,omitempty"`
4039
}
4140

4241
// GCPVertexAIVendorFields contains GCP Vertex AI (Gemini) vendor-specific fields.
4342
type GCPVertexAIVendorFields struct {
4443
// GenerationConfig holds Gemini generation configuration options.
45-
GenerationConfig *GCPVertexAIGenerationConfig `json:"generationConfig,omitempty"`
46-
}
47-
48-
// GCPVertexAIGenerationConfig represents Gemini generation configuration options.
49-
type GCPVertexAIGenerationConfig struct {
50-
ThinkingConfig *genai.GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
51-
}
52-
53-
// AnthropicVendorFields contains GCP Anthropic-specific fields.
54-
type AnthropicVendorFields struct {
55-
Thinking *anthropic.ThinkingConfigParamUnion `json:"thinking,omitzero"`
44+
// Currently only a subset of the options are supported.
45+
//
46+
// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig
47+
GenerationConfig *GCPVertexAIGenerationConfig `json:"generationConfig,omitzero"`
48+
49+
// SafetySettings: Safety settings in the request to block unsafe content in the response.
50+
//
51+
// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/SafetySetting
52+
SafetySettings []*genai.SafetySetting `json:"safetySettings,omitzero"`
5653
}
5754
```
5855

internal/apischema/openai/openai.go

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,71 @@ type WebSearchLocation struct {
819819
Country string `json:"country,omitempty"`
820820
}
821821

822+
// ThinkingConfig contains thinking config for reasoning models
823+
type ThinkingUnion struct {
824+
OfEnabled *ThinkingEnabled `json:",omitzero,inline"`
825+
OfDisabled *ThinkingDisabled `json:",omitzero,inline"`
826+
}
827+
828+
type ThinkingEnabled struct {
829+
// Determines how many tokens the model can use for its internal reasoning process.
830+
// Larger budgets can enable more thorough analysis for complex problems, improving
831+
// response quality.
832+
BudgetTokens int64 `json:"budget_tokens"`
833+
// This field can be elided, and will marshal its zero value as "enabled".
834+
Type string `json:"type"`
835+
836+
// Optional. Indicates the thinking budget in tokens.
837+
IncludeThoughts bool `json:"includeThoughts,omitempty"`
838+
}
839+
840+
type ThinkingDisabled struct {
841+
Type string `json:"type,"`
842+
}
843+
844+
// MarshalJSON implements the json.Marshaler interface for ThinkingUnion.
845+
func (t *ThinkingUnion) MarshalJSON() ([]byte, error) {
846+
if t.OfEnabled != nil {
847+
return json.Marshal(t.OfEnabled)
848+
}
849+
if t.OfDisabled != nil {
850+
return json.Marshal(t.OfDisabled)
851+
}
852+
// If both are nil, return an empty object or an error, depending on your desired behavior.
853+
return []byte(`{}`), nil
854+
}
855+
856+
// UnmarshalJSON implements the json.Unmarshaler interface for ThinkingUnion.
857+
func (t *ThinkingUnion) UnmarshalJSON(data []byte) error {
858+
// Use a temporary struct to determine the type
859+
typeResult := gjson.GetBytes(data, "type")
860+
if !typeResult.Exists() {
861+
return errors.New("thinking config does not have a type")
862+
}
863+
864+
// Based on the 'type' field, unmarshal into the correct struct.
865+
typeVal := typeResult.String()
866+
867+
switch typeVal {
868+
case "enabled":
869+
var enabled ThinkingEnabled
870+
if err := json.Unmarshal(data, &enabled); err != nil {
871+
return err
872+
}
873+
t.OfEnabled = &enabled
874+
case "disabled":
875+
var disabled ThinkingDisabled
876+
if err := json.Unmarshal(data, &disabled); err != nil {
877+
return err
878+
}
879+
t.OfDisabled = &disabled
880+
default:
881+
return fmt.Errorf("invalid thinking union type: %s", typeVal)
882+
}
883+
884+
return nil
885+
}
886+
822887
type ChatCompletionRequest struct {
823888
// Messages: A list of messages comprising the conversation so far.
824889
// Depending on the model you use, different message types (modalities) are supported,
@@ -982,9 +1047,6 @@ type ChatCompletionRequest struct {
9821047
// GCPVertexAIVendorFields configures the GCP VertexAI specific fields during schema translation.
9831048
*GCPVertexAIVendorFields `json:",inline,omitempty"`
9841049

985-
// AnthropicVendorFields configures the Anthropic specific fields during schema translation.
986-
*AnthropicVendorFields `json:",inline,omitempty"`
987-
9881050
// GuidedChoice: The output will be exactly one of the choices.
9891051
GuidedChoice []string `json:"guided_choice,omitzero"`
9901052

@@ -993,6 +1055,9 @@ type ChatCompletionRequest struct {
9931055

9941056
// GuidedJSON: The output will follow the JSON schema.
9951057
GuidedJSON json.RawMessage `json:"guided_json,omitzero"`
1058+
1059+
// Thinking: The thinking config for reasoning models
1060+
Thinking *ThinkingUnion `json:"thinking,omitzero"`
9961061
}
9971062

9981063
type StreamOptions struct {
@@ -1578,23 +1643,10 @@ type GCPVertexAIVendorFields struct {
15781643

15791644
// GCPVertexAIGenerationConfig represents Gemini generation configuration options.
15801645
type GCPVertexAIGenerationConfig struct {
1581-
// ThinkingConfig holds Gemini thinking configuration options.
1582-
//
1583-
// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig#ThinkingConfig
1584-
ThinkingConfig *genai.ThinkingConfig `json:"thinkingConfig,omitzero"`
1585-
15861646
// MediaResolution is to set global media resolution in gemini models: https://ai.google.dev/api/caching#MediaResolution
15871647
MediaResolution genai.MediaResolution `json:"media_resolution,omitempty"`
15881648
}
15891649

1590-
// AnthropicVendorFields contains Anthropic vendor-specific fields.
1591-
type AnthropicVendorFields struct {
1592-
// Thinking holds Anthropic thinking configuration options.
1593-
//
1594-
// https://docs.anthropic.com/en/api/messages#body-thinking
1595-
Thinking *anthropic.ThinkingConfigParamUnion `json:"thinking,omitzero"`
1596-
}
1597-
15981650
// ReasoningContentUnion content regarding the reasoning that is carried out by the model.
15991651
// Reasoning refers to a Chain of Thought (CoT) that the model generates to enhance the accuracy of its final response.
16001652
type ReasoningContentUnion struct {

internal/apischema/openai/vendor_fields_test.go

Lines changed: 0 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
"github.com/openai/openai-go/v2/packages/param"
1717
"github.com/stretchr/testify/require"
1818
"google.golang.org/genai"
19-
"k8s.io/utils/ptr"
2019
)
2120

2221
func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) {
@@ -36,12 +35,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) {
3635
"content": "Hello, world!"
3736
}
3837
],
39-
"generationConfig": {
40-
"thinkingConfig": {
41-
"includeThoughts": true,
42-
"thinkingBudget": 1000
43-
}
44-
},
4538
"safetySettings": [{
4639
"category": "HARM_CATEGORY_HARASSMENT",
4740
"threshold": "BLOCK_ONLY_HIGH"
@@ -58,12 +51,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) {
5851
},
5952
},
6053
GCPVertexAIVendorFields: &GCPVertexAIVendorFields{
61-
GenerationConfig: &GCPVertexAIGenerationConfig{
62-
ThinkingConfig: &genai.ThinkingConfig{
63-
IncludeThoughts: true,
64-
ThinkingBudget: ptr.To(int32(1000)),
65-
},
66-
},
6754
SafetySettings: []*genai.SafetySetting{
6855
{
6956
Category: genai.HarmCategoryHarassment,
@@ -73,55 +60,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) {
7360
},
7461
},
7562
},
76-
{
77-
name: "Request with multiple vendor fields",
78-
jsonData: []byte(`{
79-
"model": "claude-3",
80-
"messages": [
81-
{
82-
"role": "user",
83-
"content": "Multiple vendors test"
84-
}
85-
],
86-
"generationConfig": {
87-
"thinkingConfig": {
88-
"includeThoughts": true,
89-
"thinkingBudget": 1000
90-
}
91-
},
92-
"thinking": {
93-
"type": "enabled",
94-
"budget_tokens": 1000
95-
}
96-
}`),
97-
expected: &ChatCompletionRequest{
98-
Model: "claude-3",
99-
Messages: []ChatCompletionMessageParamUnion{
100-
{
101-
OfUser: &ChatCompletionUserMessageParam{
102-
Role: ChatMessageRoleUser,
103-
Content: StringOrUserRoleContentUnion{Value: "Multiple vendors test"},
104-
},
105-
},
106-
},
107-
AnthropicVendorFields: &AnthropicVendorFields{
108-
Thinking: &anthropic.ThinkingConfigParamUnion{
109-
OfEnabled: &anthropic.ThinkingConfigEnabledParam{
110-
BudgetTokens: 1000,
111-
Type: "enabled",
112-
},
113-
},
114-
},
115-
GCPVertexAIVendorFields: &GCPVertexAIVendorFields{
116-
GenerationConfig: &GCPVertexAIGenerationConfig{
117-
ThinkingConfig: &genai.ThinkingConfig{
118-
IncludeThoughts: true,
119-
ThinkingBudget: ptr.To(int32(1000)),
120-
},
121-
},
122-
},
123-
},
124-
},
12563
{
12664
name: "Request without vendor fields",
12765
jsonData: []byte(`{
@@ -252,45 +190,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) {
252190
},
253191
},
254192
},
255-
{
256-
name: "Request with both detail and thinkingConfig fields",
257-
jsonData: []byte(`{
258-
"model": "gemini-1.5-pro",
259-
"messages": [
260-
{
261-
"role": "user",
262-
"content": "Test with both detail and thinking config"
263-
}
264-
],
265-
"generationConfig": {
266-
"media_resolution": "medium",
267-
"thinkingConfig": {
268-
"includeThoughts": true,
269-
"thinkingBudget": 500
270-
}
271-
}
272-
}`),
273-
expected: &ChatCompletionRequest{
274-
Model: "gemini-1.5-pro",
275-
Messages: []ChatCompletionMessageParamUnion{
276-
{
277-
OfUser: &ChatCompletionUserMessageParam{
278-
Role: ChatMessageRoleUser,
279-
Content: StringOrUserRoleContentUnion{Value: "Test with both detail and thinking config"},
280-
},
281-
},
282-
},
283-
GCPVertexAIVendorFields: &GCPVertexAIVendorFields{
284-
GenerationConfig: &GCPVertexAIGenerationConfig{
285-
MediaResolution: "medium",
286-
ThinkingConfig: &genai.ThinkingConfig{
287-
IncludeThoughts: true,
288-
ThinkingBudget: ptr.To(int32(500)),
289-
},
290-
},
291-
},
292-
},
293-
},
294193
}
295194

296195
for _, tt := range tests {

internal/translator/openai_awsbedrock.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,29 @@ type openAIToAWSBedrockTranslatorV1ChatCompletion struct {
4848
activeToolStream bool
4949
}
5050

51+
func getAwsBedrockThinkingMap(tu *openai.ThinkingUnion) map[string]any {
52+
if tu == nil {
53+
return nil
54+
}
55+
56+
resultMap := make(map[string]any)
57+
58+
if tu.OfEnabled != nil {
59+
reasoningConfigMap := map[string]any{
60+
"type": "enabled",
61+
"budget_tokens": tu.OfEnabled.BudgetTokens,
62+
}
63+
resultMap["thinking"] = reasoningConfigMap
64+
} else if tu.OfDisabled != nil {
65+
reasoningConfigMap := map[string]any{
66+
"type": "disabled",
67+
}
68+
resultMap["thinking"] = reasoningConfigMap
69+
}
70+
71+
return resultMap
72+
}
73+
5174
// RequestBody implements [OpenAIChatCompletionTranslator.RequestBody].
5275
func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) RequestBody(_ []byte, openAIReq *openai.ChatCompletionRequest, _ bool) (
5376
newHeaders []internalapi.Header, newBody []byte, err error,
@@ -83,12 +106,12 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) RequestBody(_ []byte, ope
83106
bedrockReq.InferenceConfig.StopSequences = openAIReq.Stop.OfStringArray
84107
}
85108

86-
// Handle Anthropic vendor fields if present. Currently only supports thinking fields.
87-
if openAIReq.AnthropicVendorFields != nil && openAIReq.Thinking != nil {
109+
// Handle thinking config
110+
if openAIReq.Thinking != nil {
88111
if bedrockReq.AdditionalModelRequestFields == nil {
89112
bedrockReq.AdditionalModelRequestFields = make(map[string]interface{})
90113
}
91-
bedrockReq.AdditionalModelRequestFields["thinking"] = openAIReq.Thinking
114+
bedrockReq.AdditionalModelRequestFields = getAwsBedrockThinkingMap(openAIReq.Thinking)
92115
}
93116

94117
// Convert Chat Completion messages.

internal/translator/openai_awsbedrock_test.go

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717
"testing"
1818
"time"
1919

20-
"github.com/anthropics/anthropic-sdk-go"
2120
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
2221
"github.com/google/go-cmp/cmp"
2322
"github.com/google/go-cmp/cmp/cmpopts"
@@ -897,11 +896,10 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T)
897896
},
898897
},
899898
},
900-
AnthropicVendorFields: &openai.AnthropicVendorFields{
901-
Thinking: &anthropic.ThinkingConfigParamUnion{
902-
OfEnabled: &anthropic.ThinkingConfigEnabledParam{
903-
BudgetTokens: int64(1024),
904-
},
899+
Thinking: &openai.ThinkingUnion{
900+
OfEnabled: &openai.ThinkingEnabled{
901+
BudgetTokens: int64(1024),
902+
Type: "enabled",
905903
},
906904
},
907905
},
@@ -1115,12 +1113,10 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T)
11151113
},
11161114
},
11171115
},
1118-
AnthropicVendorFields: &openai.AnthropicVendorFields{
1119-
Thinking: &anthropic.ThinkingConfigParamUnion{
1120-
OfEnabled: &anthropic.ThinkingConfigEnabledParam{
1121-
Type: "enabled",
1122-
BudgetTokens: 1024,
1123-
},
1116+
Thinking: &openai.ThinkingUnion{
1117+
OfEnabled: &openai.ThinkingEnabled{
1118+
Type: "enabled",
1119+
BudgetTokens: 1024,
11241120
},
11251121
},
11261122
},
@@ -1149,11 +1145,9 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T)
11491145
},
11501146
},
11511147
},
1152-
AnthropicVendorFields: &openai.AnthropicVendorFields{
1153-
Thinking: &anthropic.ThinkingConfigParamUnion{
1154-
OfDisabled: &anthropic.ThinkingConfigDisabledParam{
1155-
Type: "disabled",
1156-
},
1148+
Thinking: &openai.ThinkingUnion{
1149+
OfDisabled: &openai.ThinkingDisabled{
1150+
Type: "disabled",
11571151
},
11581152
},
11591153
},

0 commit comments

Comments
 (0)