Skip to content

Commit 6e9eca0

Browse files
committed
Support Gemini thinking budget in inference API
Add support for configuring the thinkingBudget for Gemini 2.5 models when creating chat completion inference endpoints. The thinking_budget field is nested inside the thinking_config object in service_settings.
- Added a ThinkingConfig class to contain the thinking_budget field. This results in a less flat structure for the PUT _inference/chat_completion/ call, but it will make adding support for include_thoughts easier in the future.
- Added an extractOptionalInteger() method to ServiceUtils.
- Added unit tests for the ThinkingConfig class.
- Updated existing tests to account for the new object and field.
These changes enable elastic/kibana#227590 to be completed.
1 parent 89564db commit 6e9eca0

File tree

13 files changed

+517
-49
lines changed

13 files changed

+517
-49
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ static TransportVersion def(int id) {
369369
public static final TransportVersion SCRIPT_RESCORER = def(9_143_0_00);
370370
public static final TransportVersion ESQL_LOOKUP_OPERATOR_EMITTED_ROWS = def(9_144_0_00);
371371
public static final TransportVersion ALLOCATION_DECISION_NOT_PREFERRED = def(9_145_0_00);
372+
public static final TransportVersion GEMINI_THINKING_BUDGET_ADDED = def(9_146_0_00);
372373

373374
/*
374375
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,25 @@ public static Integer extractOptionalPositiveInteger(
780780
String settingName,
781781
String scope,
782782
ValidationException validationException
783+
) {
784+
return extractOptionalInteger(map, settingName, scope, validationException, true);
785+
}
786+
787+
/**
 * Extracts an optional integer setting without enforcing positivity.
 * <p>
 * Delegates to the private overload with {@code mustBePositive = false}, so the
 * "must be a positive integer" check applied by {@code extractOptionalPositiveInteger}
 * is skipped and zero or negative values are accepted.
 *
 * @param map the settings map the value is read from
 * @param settingName the key of the setting to extract
 * @param scope the settings scope passed through to the private overload
 * @param validationException collector for validation errors
 * @return the result of the private overload
 */
public static Integer extractOptionalInteger(
788+
Map<String, Object> map,
789+
String settingName,
790+
String scope,
791+
ValidationException validationException
792+
) {
793+
return extractOptionalInteger(map, settingName, scope, validationException, false);
794+
}
795+
796+
private static Integer extractOptionalInteger(
797+
Map<String, Object> map,
798+
String settingName,
799+
String scope,
800+
ValidationException validationException,
801+
boolean mustBePositive
783802
) {
784803
int initialValidationErrorCount = validationException.validationErrors().size();
785804
Integer optionalField = ServiceUtils.removeAsType(map, settingName, Integer.class, validationException);
@@ -788,7 +807,7 @@ public static Integer extractOptionalPositiveInteger(
788807
return null;
789808
}
790809

791-
if (optionalField != null && optionalField <= 0) {
810+
if (optionalField != null && mustBePositive && optionalField <= 0) {
792811
validationException.addValidationError(ServiceUtils.mustBeAPositiveIntegerErrorMessage(settingName, scope, optionalField));
793812
return null;
794813
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionM
8080
originalModelServiceSettings.projectId(),
8181
originalModelServiceSettings.location(),
8282
Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()),
83-
originalModelServiceSettings.rateLimitSettings()
83+
originalModelServiceSettings.rateLimitSettings(),
84+
originalModelServiceSettings.thinkingConfig()
8485
);
8586

8687
return new GoogleVertexAiChatCompletionModel(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.common.Strings;
1213
import org.elasticsearch.common.ValidationException;
1314
import org.elasticsearch.common.io.stream.StreamInput;
1415
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -44,12 +45,14 @@ public class GoogleVertexAiChatCompletionServiceSettings extends FilteredXConten
4445
private final String projectId;
4546

4647
private final RateLimitSettings rateLimitSettings;
48+
private final ThinkingConfig thinkingConfig;
4749

4850
// https://cloud.google.com/vertex-ai/docs/quotas#eval-quotas
4951
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1000);
52+
private static final ThinkingConfig EMPTY_THINKING_CONFIG = new ThinkingConfig();
5053

5154
public GoogleVertexAiChatCompletionServiceSettings(StreamInput in) throws IOException {
52-
this(in.readString(), in.readString(), in.readString(), new RateLimitSettings(in));
55+
this(in.readString(), in.readString(), in.readString(), new RateLimitSettings(in), new ThinkingConfig(in));
5356
}
5457

5558
@Override
@@ -58,6 +61,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil
5861
builder.field(LOCATION, location);
5962
builder.field(MODEL_ID, modelId);
6063
rateLimitSettings.toXContent(builder, params);
64+
thinkingConfig.toXContent(builder, params);
6165
return builder;
6266
}
6367

@@ -78,23 +82,34 @@ public static GoogleVertexAiChatCompletionServiceSettings fromMap(Map<String, Ob
7882
context
7983
);
8084

85+
// Extract optional thinkingConfig settings
86+
ThinkingConfig thinkingConfig = ThinkingConfig.of(
87+
map,
88+
EMPTY_THINKING_CONFIG,
89+
validationException,
90+
GoogleVertexAiService.NAME,
91+
context
92+
);
93+
8194
if (validationException.validationErrors().isEmpty() == false) {
8295
throw validationException;
8396
}
8497

85-
return new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings);
98+
return new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings, thinkingConfig);
8699
}
87100

88101
public GoogleVertexAiChatCompletionServiceSettings(
89102
String projectId,
90103
String location,
91104
String modelId,
92-
@Nullable RateLimitSettings rateLimitSettings
105+
@Nullable RateLimitSettings rateLimitSettings,
106+
@Nullable ThinkingConfig thinkingConfig
93107
) {
94108
this.projectId = projectId;
95109
this.location = location;
96110
this.modelId = modelId;
97111
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
112+
this.thinkingConfig = Objects.requireNonNullElse(thinkingConfig, EMPTY_THINKING_CONFIG);
98113
}
99114

100115
public String location() {
@@ -116,6 +131,15 @@ public String getWriteableName() {
116131
return NAME;
117132
}
118133

134+
/**
 * Returns the rate limit settings for this endpoint. The public constructor replaces a
 * {@code null} argument with {@code DEFAULT_RATE_LIMIT_SETTINGS}, so this is never {@code null}.
 */
@Override
135+
public RateLimitSettings rateLimitSettings() {
136+
return rateLimitSettings;
137+
}
138+
139+
/**
 * Returns the thinking configuration for this endpoint. The public constructor replaces a
 * {@code null} argument with {@code EMPTY_THINKING_CONFIG}, so this is never {@code null};
 * use {@link ThinkingConfig#isEmpty()} to check whether a budget was configured.
 */
public ThinkingConfig thinkingConfig() {
140+
return thinkingConfig;
141+
}
142+
119143
@Override
120144
public TransportVersion getMinimalSupportedVersion() {
121145
assert false : "should never be called when supportsVersion is used";
@@ -134,6 +158,7 @@ public void writeTo(StreamOutput out) throws IOException {
134158
out.writeString(location);
135159
out.writeString(modelId);
136160
rateLimitSettings.writeTo(out);
161+
thinkingConfig.writeTo(out);
137162
}
138163

139164
@Override
@@ -152,16 +177,17 @@ public boolean equals(Object o) {
152177
return Objects.equals(location, that.location)
153178
&& Objects.equals(modelId, that.modelId)
154179
&& Objects.equals(projectId, that.projectId)
155-
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
180+
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
181+
&& Objects.equals(thinkingConfig, that.thinkingConfig);
156182
}
157183

158184
@Override
159185
public int hashCode() {
160-
return Objects.hash(location, modelId, projectId, rateLimitSettings);
186+
return Objects.hash(location, modelId, projectId, rateLimitSettings, thinkingConfig);
161187
}
162188

163189
@Override
164-
public RateLimitSettings rateLimitSettings() {
165-
return rateLimitSettings;
190+
public String toString() {
191+
return Strings.toString(this);
166192
}
167193
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai.completion;
9+
10+
import org.elasticsearch.TransportVersions;
11+
import org.elasticsearch.common.Strings;
12+
import org.elasticsearch.common.ValidationException;
13+
import org.elasticsearch.common.io.stream.StreamInput;
14+
import org.elasticsearch.common.io.stream.StreamOutput;
15+
import org.elasticsearch.common.io.stream.Writeable;
16+
import org.elasticsearch.inference.ModelConfigurations;
17+
import org.elasticsearch.xcontent.ToXContentFragment;
18+
import org.elasticsearch.xcontent.XContentBuilder;
19+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
20+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
21+
22+
import java.io.IOException;
23+
import java.util.Map;
24+
import java.util.Objects;
25+
26+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
27+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
28+
29+
/**
30+
* This class encapsulates the ThinkingConfig object contained within GenerationConfig. Only the thinkingBudget field is currently
31+
* supported, but the includeThoughts field may be added in the future
32+
*/
33+
public class ThinkingConfig implements Writeable, ToXContentFragment {
34+
// Key of the nested settings object; read in of() and written in toXContent().
public static final String THINKING_CONFIG_FIELD = "thinking_config";
35+
// Key of the budget value inside the thinking_config object.
public static final String THINKING_BUDGET_FIELD = "thinking_budget";
36+
37+
// null means no budget was configured; isEmpty() reports that state.
private final Integer thinkingBudget;
38+
39+
/**
40+
* Constructor for an empty {@code ThinkingConfig}
41+
*/
42+
public ThinkingConfig() {
43+
this.thinkingBudget = null;
44+
}
45+
46+
/**
 * Creates a config holding the given budget. No range validation is applied here;
 * a {@code null} budget produces an instance for which {@link #isEmpty()} is true.
 *
 * @param thinkingBudget the budget value, or {@code null} for none
 */
public ThinkingConfig(Integer thinkingBudget) {
47+
this.thinkingBudget = thinkingBudget;
48+
}
49+
50+
/**
 * Deserializes a config from a stream. The budget is only read when the stream's
 * transport version is on or after {@code GEMINI_THINKING_BUDGET_ADDED}; for older
 * streams nothing is consumed and the config is empty. This mirrors {@link #writeTo}.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public ThinkingConfig(StreamInput in) throws IOException {
51+
if (in.getTransportVersion().onOrAfter(TransportVersions.GEMINI_THINKING_BUDGET_ADDED)) {
52+
thinkingBudget = in.readOptionalVInt();
53+
} else {
54+
thinkingBudget = null;
55+
}
56+
}
57+
58+
/**
 * Parses the {@code thinking_config} object out of a settings map.
 * <p>
 * The {@code thinking_config} entry is removed from {@code map}, and the
 * {@code thinking_budget} value is extracted from it without a positivity check.
 *
 * @param map the settings map; the {@code thinking_config} entry is removed from it
 * @param defaultValue returned when no budget value was extracted
 * @param validationException collector passed to the integer extraction
 * @param serviceName service name passed to {@code throwIfNotEmptyMap}
 * @param context parse context; leftover keys are only checked in a request context
 * @return {@code defaultValue} if the extracted budget is {@code null}, otherwise a new config
 */
public static ThinkingConfig of(
59+
Map<String, Object> map,
60+
ThinkingConfig defaultValue,
61+
ValidationException validationException,
62+
String serviceName,
63+
ConfigurationParseContext context
64+
) {
65+
// NOTE(review): presumably yields an empty map when the key is absent - confirm against ServiceUtils.
Map<String, Object> thinkingConfigSettings = removeFromMapOrDefaultEmpty(map, THINKING_CONFIG_FIELD);
66+
Integer thinkingBudget = ServiceUtils.extractOptionalInteger(
67+
thinkingConfigSettings,
68+
THINKING_BUDGET_FIELD,
69+
ModelConfigurations.SERVICE_SETTINGS,
70+
validationException
71+
);
72+
73+
// NOTE(review): presumably rejects any keys left in thinking_config after extraction - confirm against ServiceUtils.
if (ConfigurationParseContext.isRequestContext(context)) {
74+
throwIfNotEmptyMap(thinkingConfigSettings, serviceName);
75+
}
76+
77+
return thinkingBudget == null ? defaultValue : new ThinkingConfig(thinkingBudget);
78+
}
79+
80+
/**
 * @return {@code true} when no thinking budget is set
 */
public boolean isEmpty() {
81+
return thinkingBudget == null;
82+
}
83+
84+
/**
 * @return the configured thinking budget, or {@code null} when none is set
 */
public Integer getThinkingBudget() {
85+
return thinkingBudget;
86+
}
87+
88+
/**
 * Serializes the budget, but only when the target stream's transport version is on or
 * after {@code GEMINI_THINKING_BUDGET_ADDED}; for older targets nothing is written,
 * matching what the stream constructor expects to read.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
89+
public void writeTo(StreamOutput out) throws IOException {
90+
if (out.getTransportVersion().onOrAfter(TransportVersions.GEMINI_THINKING_BUDGET_ADDED)) {
91+
out.writeOptionalVInt(thinkingBudget);
92+
}
93+
}
94+
95+
/**
 * Writes a {@code thinking_config} object containing {@code thinking_budget} into the
 * builder. Nothing is written when no budget is set.
 *
 * @param builder the builder to append to
 * @param params serialization parameters (unused here)
 * @return the same builder, for chaining
 * @throws IOException if writing to the builder fails
 */
@Override
96+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
97+
if (thinkingBudget != null) {
98+
builder.startObject(THINKING_CONFIG_FIELD);
99+
builder.field(THINKING_BUDGET_FIELD, thinkingBudget);
100+
builder.endObject();
101+
}
102+
return builder;
103+
}
104+
105+
/**
 * Two configs are equal when they are of the same class and hold equal budgets
 * (including both being unset).
 */
@Override
106+
public boolean equals(Object o) {
107+
if (o == null || getClass() != o.getClass()) return false;
108+
ThinkingConfig that = (ThinkingConfig) o;
109+
return Objects.equals(thinkingBudget, that.thinkingBudget);
110+
}
111+
112+
/**
 * Hash derived from the budget only, consistent with {@link #equals(Object)}.
 */
@Override
113+
public int hashCode() {
114+
return Objects.hashCode(thinkingBudget);
115+
}
116+
117+
/**
 * String form produced by passing this object to {@code Strings.toString}.
 */
@Override
118+
public String toString() {
119+
return Strings.toString(this);
120+
}
121+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ public GoogleVertexAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatIn
3737
public HttpRequest createHttpRequest() {
3838
HttpPost httpPost = new HttpPost(uri);
3939

40-
var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput);
40+
var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(
41+
unifiedChatInput,
42+
model.getServiceSettings().thinkingConfig()
43+
);
4144

4245
ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8));
4346
httpPost.setEntity(byteEntity);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.xcontent.XContentParserConfiguration;
2020
import org.elasticsearch.xcontent.XContentType;
2121
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
22+
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig;
2223

2324
import java.io.IOException;
2425
import java.util.Locale;
@@ -37,6 +38,8 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont
3738
private static final String TEMPERATURE = "temperature";
3839
private static final String MAX_OUTPUT_TOKENS = "maxOutputTokens";
3940
private static final String TOP_P = "topP";
41+
private static final String THINKING_CONFIG = "thinkingConfig";
42+
private static final String THINKING_BUDGET = "thinkingBudget";
4043

4144
private static final String TOOLS = "tools";
4245
private static final String FUNCTION_DECLARATIONS = "functionDeclarations";
@@ -56,6 +59,7 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont
5659
private static final String FUNCTION_CALL_ARGS = "args";
5760

5861
private final UnifiedChatInput unifiedChatInput;
62+
private final ThinkingConfig thinkingConfig;
5963

6064
private static final String USER_ROLE = "user";
6165
private static final String MODEL_ROLE = "model";
@@ -66,8 +70,9 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont
6670

6771
private static final String SYSTEM_INSTRUCTION = "systemInstruction";
6872

69-
public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) {
73+
public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, ThinkingConfig thinkingConfig) {
7074
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
75+
this.thinkingConfig = Objects.requireNonNull(thinkingConfig);
7176
}
7277

7378
private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) {
@@ -316,7 +321,8 @@ private void buildGenerationConfig(XContentBuilder builder) throws IOException {
316321
boolean hasAnyConfig = request.stop() != null
317322
|| request.temperature() != null
318323
|| request.maxCompletionTokens() != null
319-
|| request.topP() != null;
324+
|| request.topP() != null
325+
|| thinkingConfig.isEmpty() == false;
320326

321327
if (hasAnyConfig == false) {
322328
return;
@@ -336,6 +342,11 @@ private void buildGenerationConfig(XContentBuilder builder) throws IOException {
336342
if (request.topP() != null) {
337343
builder.field(TOP_P, request.topP());
338344
}
345+
if (thinkingConfig.isEmpty() == false) {
346+
builder.startObject(THINKING_CONFIG);
347+
builder.field(THINKING_BUDGET, thinkingConfig.getThinkingBudget());
348+
builder.endObject();
349+
}
339350

340351
builder.endObject();
341352
}

0 commit comments

Comments
 (0)