Skip to content

Commit ade279d

Browse files
committed
feat(gemini): add Gemini thinking configuration support and metadata normalization

- Introduced logic to parse and apply `thinkingBudget` and `include_thoughts` configurations from metadata.
- Enhanced request handling to include normalized Gemini model metadata, preserving the original model identifier.
- Updated Gemini and Gemini-CLI executors to apply thinking configuration based on metadata overrides.
- Refactored handlers to support metadata extraction and cloning during request preparation.
1 parent 9c5ac29 commit ade279d

File tree

4 files changed

+257
-6
lines changed

4 files changed

+257
-6
lines changed

internal/runtime/executor/gemini_cli_executor.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
6060

6161
from := opts.SourceFormat
6262
to := sdktranslator.FromString("gemini-cli")
63+
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
6364
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
65+
if hasOverride {
66+
basePayload = util.ApplyGeminiCLIThinkingConfig(basePayload, budgetOverride, includeOverride)
67+
}
6468
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
6569

6670
action := "generateContent"
@@ -149,7 +153,11 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
149153

150154
from := opts.SourceFormat
151155
to := sdktranslator.FromString("gemini-cli")
156+
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
152157
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
158+
if hasOverride {
159+
basePayload = util.ApplyGeminiCLIThinkingConfig(basePayload, budgetOverride, includeOverride)
160+
}
153161
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
154162

155163
projectID := strings.TrimSpace(stringValue(auth.Metadata, "project_id"))
@@ -292,8 +300,12 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
292300
var lastStatus int
293301
var lastBody []byte
294302

303+
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
295304
for _, attemptModel := range models {
296305
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
306+
if hasOverride {
307+
payload = util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
308+
}
297309
payload = deleteJSONField(payload, "project")
298310
payload = deleteJSONField(payload, "model")
299311
payload = disableGeminiThinkingConfig(payload, attemptModel)

internal/runtime/executor/gemini_executor.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
7777
from := opts.SourceFormat
7878
to := sdktranslator.FromString("gemini")
7979
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
80+
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok {
81+
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
82+
}
8083
body = disableGeminiThinkingConfig(body, req.Model)
8184
body = fixGeminiImageAspectRatio(req.Model, body)
8285

@@ -136,6 +139,9 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
136139
from := opts.SourceFormat
137140
to := sdktranslator.FromString("gemini")
138141
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
142+
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok {
143+
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
144+
}
139145
body = disableGeminiThinkingConfig(body, req.Model)
140146
body = fixGeminiImageAspectRatio(req.Model, body)
141147

@@ -208,6 +214,9 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
208214
from := opts.SourceFormat
209215
to := sdktranslator.FromString("gemini")
210216
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
217+
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok {
218+
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
219+
}
211220
translatedReq = disableGeminiThinkingConfig(translatedReq, req.Model)
212221
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
213222
respCtx := context.WithValue(ctx, "alt", opts.Alt)

internal/util/gemini_thinking.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
package util
2+
3+
import (
4+
"encoding/json"
5+
"strconv"
6+
"strings"
7+
8+
"github.com/tidwall/sjson"
9+
)
10+
11+
// Metadata keys used to carry Gemini thinking overrides alongside a request.
const (
12+
// GeminiThinkingBudgetMetadataKey is the metadata key holding the thinking
// budget override read by GeminiThinkingFromMetadata.
GeminiThinkingBudgetMetadataKey = "gemini_thinking_budget"
13+
// GeminiIncludeThoughtsMetadataKey is the metadata key holding the
// include-thoughts override read by GeminiThinkingFromMetadata.
GeminiIncludeThoughtsMetadataKey = "gemini_include_thoughts"
14+
// GeminiOriginalModelMetadataKey is the metadata key under which the model
// name as originally requested (before suffix stripping) is stored.
GeminiOriginalModelMetadataKey = "gemini_original_model"
15+
)
16+
17+
// ParseGeminiThinkingSuffix splits a thinking-configuration suffix off a Gemini
// model name.
//
// Only names that start with "gemini-" are considered; matching is
// case-insensitive for ASCII letters. Two suffix forms are recognized:
//
//   - "<base>-nothinking": yields a budget of 0 (128 when the name starts with
//     "gemini-2.5-pro"; presumably that family cannot run with a zero budget,
//     confirm against the upstream API) and include-thoughts set to false.
//   - "<base>-thinking-<digits>": yields the parsed digits as the budget and a
//     nil include-thoughts value. Any characters following the leading run of
//     digits are ignored.
//
// It returns the base model name, the budget override, the include-thoughts
// override, and whether a suffix was recognized. When nothing is recognized the
// model is returned unchanged with nil overrides.
func ParseGeminiThinkingSuffix(model string) (string, *int, *bool, bool) {
	if model == "" {
		return model, nil, nil, false
	}
	// Offsets found in lower are used to slice model, so the lowercase copy
	// must have exactly the same byte length. strings.ToLower does not
	// guarantee that for non-ASCII input and could produce wrong slices or an
	// out-of-range panic on client-supplied model names.
	lower := geminiASCIILower(model)
	if !strings.HasPrefix(lower, "gemini-") {
		return model, nil, nil, false
	}

	const noThinkingSuffix = "-nothinking"
	if strings.HasSuffix(lower, noThinkingSuffix) {
		base := model[:len(model)-len(noThinkingSuffix)]
		budgetValue := 0
		if strings.HasPrefix(lower, "gemini-2.5-pro") {
			budgetValue = 128
		}
		include := false
		return base, &budgetValue, &include, true
	}

	const thinkingMarker = "-thinking-"
	idx := strings.LastIndex(lower, thinkingMarker)
	if idx == -1 {
		return model, nil, nil, false
	}

	digits := model[idx+len(thinkingMarker):]
	if digits == "" {
		return model, nil, nil, false
	}
	// Take only the leading run of ASCII digits.
	end := len(digits)
	for i := 0; i < len(digits); i++ {
		if digits[i] < '0' || digits[i] > '9' {
			end = i
			break
		}
	}
	if end == 0 {
		return model, nil, nil, false
	}
	// Atoi also rejects digit runs that overflow int.
	value, err := strconv.Atoi(digits[:end])
	if err != nil {
		return model, nil, nil, false
	}
	return model[:idx], &value, nil, true
}

// geminiASCIILower lowercases the ASCII letters of s byte-for-byte and leaves
// every other byte untouched, so the result always has the same length as s.
func geminiASCIILower(s string) string {
	b := []byte(s)
	for i, c := range b {
		if 'A' <= c && c <= 'Z' {
			b[i] = c + ('a' - 'A')
		}
	}
	return string(b)
}
64+
65+
func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool) []byte {
66+
if budget == nil && includeThoughts == nil {
67+
return body
68+
}
69+
updated := body
70+
if budget != nil {
71+
valuePath := "generationConfig.thinkingConfig.thinkingBudget"
72+
rewritten, err := sjson.SetBytes(updated, valuePath, *budget)
73+
if err == nil {
74+
updated = rewritten
75+
}
76+
}
77+
if includeThoughts != nil {
78+
valuePath := "generationConfig.thinkingConfig.include_thoughts"
79+
rewritten, err := sjson.SetBytes(updated, valuePath, *includeThoughts)
80+
if err == nil {
81+
updated = rewritten
82+
}
83+
}
84+
return updated
85+
}
86+
87+
func ApplyGeminiCLIThinkingConfig(body []byte, budget *int, includeThoughts *bool) []byte {
88+
if budget == nil && includeThoughts == nil {
89+
return body
90+
}
91+
updated := body
92+
if budget != nil {
93+
valuePath := "request.generationConfig.thinkingConfig.thinkingBudget"
94+
rewritten, err := sjson.SetBytes(updated, valuePath, *budget)
95+
if err == nil {
96+
updated = rewritten
97+
}
98+
}
99+
if includeThoughts != nil {
100+
valuePath := "request.generationConfig.thinkingConfig.include_thoughts"
101+
rewritten, err := sjson.SetBytes(updated, valuePath, *includeThoughts)
102+
if err == nil {
103+
updated = rewritten
104+
}
105+
}
106+
return updated
107+
}
108+
109+
func GeminiThinkingFromMetadata(metadata map[string]any) (*int, *bool, bool) {
110+
if len(metadata) == 0 {
111+
return nil, nil, false
112+
}
113+
var (
114+
budgetPtr *int
115+
includePtr *bool
116+
matched bool
117+
)
118+
if rawBudget, ok := metadata[GeminiThinkingBudgetMetadataKey]; ok {
119+
switch v := rawBudget.(type) {
120+
case int:
121+
budget := v
122+
budgetPtr = &budget
123+
matched = true
124+
case int32:
125+
budget := int(v)
126+
budgetPtr = &budget
127+
matched = true
128+
case int64:
129+
budget := int(v)
130+
budgetPtr = &budget
131+
matched = true
132+
case float64:
133+
budget := int(v)
134+
budgetPtr = &budget
135+
matched = true
136+
case json.Number:
137+
if val, err := v.Int64(); err == nil {
138+
budget := int(val)
139+
budgetPtr = &budget
140+
matched = true
141+
}
142+
}
143+
}
144+
if rawInclude, ok := metadata[GeminiIncludeThoughtsMetadataKey]; ok {
145+
switch v := rawInclude.(type) {
146+
case bool:
147+
include := v
148+
includePtr = &include
149+
matched = true
150+
case string:
151+
if parsed, err := strconv.ParseBool(v); err == nil {
152+
include := parsed
153+
includePtr = &include
154+
matched = true
155+
}
156+
case json.Number:
157+
if val, err := v.Int64(); err == nil {
158+
include := val != 0
159+
includePtr = &include
160+
matched = true
161+
}
162+
case int:
163+
include := v != 0
164+
includePtr = &include
165+
matched = true
166+
case int32:
167+
include := v != 0
168+
includePtr = &include
169+
matched = true
170+
case int64:
171+
include := v != 0
172+
includePtr = &include
173+
matched = true
174+
case float64:
175+
include := v != 0
176+
includePtr = &include
177+
matched = true
178+
}
179+
}
180+
return budgetPtr, includePtr, matched
181+
}

sdk/api/handlers/handlers.go

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,20 +133,27 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
133133
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
134134
// This path is the only supported execution route.
135135
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
136-
providers := util.GetProviderName(modelName)
136+
normalizedModel, metadata := normalizeModelMetadata(modelName)
137+
providers := util.GetProviderName(normalizedModel)
137138
if len(providers) == 0 {
138139
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
139140
}
140141
req := coreexecutor.Request{
141-
Model: modelName,
142+
Model: normalizedModel,
142143
Payload: cloneBytes(rawJSON),
143144
}
145+
if cloned := cloneMetadata(metadata); cloned != nil {
146+
req.Metadata = cloned
147+
}
144148
opts := coreexecutor.Options{
145149
Stream: false,
146150
Alt: alt,
147151
OriginalRequest: cloneBytes(rawJSON),
148152
SourceFormat: sdktranslator.FromString(handlerType),
149153
}
154+
if cloned := cloneMetadata(metadata); cloned != nil {
155+
opts.Metadata = cloned
156+
}
150157
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
151158
if err != nil {
152159
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}
@@ -157,20 +164,27 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
157164
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
158165
// This path is the only supported execution route.
159166
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
160-
providers := util.GetProviderName(modelName)
167+
normalizedModel, metadata := normalizeModelMetadata(modelName)
168+
providers := util.GetProviderName(normalizedModel)
161169
if len(providers) == 0 {
162170
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
163171
}
164172
req := coreexecutor.Request{
165-
Model: modelName,
173+
Model: normalizedModel,
166174
Payload: cloneBytes(rawJSON),
167175
}
176+
if cloned := cloneMetadata(metadata); cloned != nil {
177+
req.Metadata = cloned
178+
}
168179
opts := coreexecutor.Options{
169180
Stream: false,
170181
Alt: alt,
171182
OriginalRequest: cloneBytes(rawJSON),
172183
SourceFormat: sdktranslator.FromString(handlerType),
173184
}
185+
if cloned := cloneMetadata(metadata); cloned != nil {
186+
opts.Metadata = cloned
187+
}
174188
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
175189
if err != nil {
176190
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}
@@ -181,23 +195,30 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
181195
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
182196
// This path is the only supported execution route.
183197
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
184-
providers := util.GetProviderName(modelName)
198+
normalizedModel, metadata := normalizeModelMetadata(modelName)
199+
providers := util.GetProviderName(normalizedModel)
185200
if len(providers) == 0 {
186201
errChan := make(chan *interfaces.ErrorMessage, 1)
187202
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
188203
close(errChan)
189204
return nil, errChan
190205
}
191206
req := coreexecutor.Request{
192-
Model: modelName,
207+
Model: normalizedModel,
193208
Payload: cloneBytes(rawJSON),
194209
}
210+
if cloned := cloneMetadata(metadata); cloned != nil {
211+
req.Metadata = cloned
212+
}
195213
opts := coreexecutor.Options{
196214
Stream: true,
197215
Alt: alt,
198216
OriginalRequest: cloneBytes(rawJSON),
199217
SourceFormat: sdktranslator.FromString(handlerType),
200218
}
219+
if cloned := cloneMetadata(metadata); cloned != nil {
220+
opts.Metadata = cloned
221+
}
201222
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
202223
if err != nil {
203224
errChan := make(chan *interfaces.ErrorMessage, 1)
@@ -232,6 +253,34 @@ func cloneBytes(src []byte) []byte {
232253
return dst
233254
}
234255

256+
func normalizeModelMetadata(modelName string) (string, map[string]any) {
257+
baseModel, budget, include, matched := util.ParseGeminiThinkingSuffix(modelName)
258+
if !matched {
259+
return baseModel, nil
260+
}
261+
metadata := map[string]any{
262+
util.GeminiOriginalModelMetadataKey: modelName,
263+
}
264+
if budget != nil {
265+
metadata[util.GeminiThinkingBudgetMetadataKey] = *budget
266+
}
267+
if include != nil {
268+
metadata[util.GeminiIncludeThoughtsMetadataKey] = *include
269+
}
270+
return baseModel, metadata
271+
}
272+
273+
// cloneMetadata returns a shallow copy of src so the caller can hand out the
// map without sharing it. A nil or empty src yields nil.
func cloneMetadata(src map[string]any) map[string]any {
	size := len(src)
	if size == 0 {
		return nil
	}
	copied := make(map[string]any, size)
	for key := range src {
		copied[key] = src[key]
	}
	return copied
}
283+
235284
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
236285
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
237286
status := http.StatusInternalServerError

0 commit comments

Comments
 (0)