Skip to content

Commit 1a1954f

Browse files
authored
[ML] Support Gemini thinking budget in inference API (#133599)
Adding support for configuring the thinkingBudget for Gemini 2.5 models when creating chat completion inference endpoints. The thinking_budget field is nested inside the thinking_config object in task_settings. These changes enable elastic/kibana#227590 to be completed - Added ThinkingConfig class to contain the thinking_budget field. This results in a less flat structure for the PUT _inference/chat_completion/ call but will make adding support for include_thoughts easier in future - Added extractOptionalInteger() method to ServiceUtils - Unit tests for ThinkingConfig class - Add test coverage for ServiceUtils.extractOptionalPositiveInteger() and extractOptionalInteger() - Updated existing tests to account for the new object and field
1 parent 2794b8d commit 1a1954f

18 files changed

+857
-53
lines changed

docs/changelog/133599.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 133599
2+
summary: Support Gemini thinking budget in inference API
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ static TransportVersion def(int id) {
359359
public static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = def(9_150_0_00);
360360
public static final TransportVersion ESQL_LOOKUP_JOIN_PRE_JOIN_FILTER = def(9_151_0_00);
361361
public static final TransportVersion INFERENCE_API_DISABLE_EIS_RATE_LIMITING = def(9_152_0_00);
362+
public static final TransportVersion GEMINI_THINKING_BUDGET_ADDED = def(9_153_0_00);
362363

363364
/*
364365
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
9090
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
9191
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionServiceSettings;
92+
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionTaskSettings;
9293
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
9394
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
9495
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
@@ -570,6 +571,14 @@ private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry
570571
)
571572
);
572573

574+
namedWriteables.add(
575+
new NamedWriteableRegistry.Entry(
576+
TaskSettings.class,
577+
GoogleVertexAiChatCompletionTaskSettings.NAME,
578+
GoogleVertexAiChatCompletionTaskSettings::new
579+
)
580+
);
581+
573582
}
574583

575584
private static void addInternalNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,25 @@ public static Integer extractOptionalPositiveInteger(
799799
String settingName,
800800
String scope,
801801
ValidationException validationException
802+
) {
803+
return extractOptionalInteger(map, settingName, scope, validationException, true);
804+
}
805+
806+
public static Integer extractOptionalInteger(
807+
Map<String, Object> map,
808+
String settingName,
809+
String scope,
810+
ValidationException validationException
811+
) {
812+
return extractOptionalInteger(map, settingName, scope, validationException, false);
813+
}
814+
815+
private static Integer extractOptionalInteger(
816+
Map<String, Object> map,
817+
String settingName,
818+
String scope,
819+
ValidationException validationException,
820+
boolean mustBePositive
802821
) {
803822
int initialValidationErrorCount = validationException.validationErrors().size();
804823
Integer optionalField = ServiceUtils.removeAsType(map, settingName, Integer.class, validationException);
@@ -807,7 +826,7 @@ public static Integer extractOptionalPositiveInteger(
807826
return null;
808827
}
809828

810-
if (optionalField != null && optionalField <= 0) {
829+
if (optionalField != null && mustBePositive && optionalField <= 0) {
811830
validationException.addValidationError(ServiceUtils.mustBeAPositiveIntegerErrorMessage(settingName, scope, optionalField));
812831
return null;
813832
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,14 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Obje
7373

7474
@Override
7575
public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings) {
76+
var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, taskSettings);
7677
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
7778

7879
var manager = new GenericRequestManager<>(
7980
serviceComponents.threadPool(),
80-
model,
81+
overriddenModel,
8182
CHAT_COMPLETION_HANDLER,
82-
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
83+
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), overriddenModel),
8384
ChatCompletionInput.class
8485
);
8586

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.apache.http.client.utils.URIBuilder;
1111
import org.elasticsearch.core.Nullable;
12-
import org.elasticsearch.inference.EmptyTaskSettings;
1312
import org.elasticsearch.inference.ModelConfigurations;
1413
import org.elasticsearch.inference.ModelSecrets;
1514
import org.elasticsearch.inference.TaskType;
@@ -47,7 +46,7 @@ public GoogleVertexAiChatCompletionModel(
4746
taskType,
4847
service,
4948
GoogleVertexAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
50-
new EmptyTaskSettings(),
49+
GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettings),
5150
GoogleVertexAiSecretSettings.fromMap(secrets)
5251
);
5352
}
@@ -57,7 +56,7 @@ public GoogleVertexAiChatCompletionModel(
5756
TaskType taskType,
5857
String service,
5958
GoogleVertexAiChatCompletionServiceSettings serviceSettings,
60-
EmptyTaskSettings taskSettings,
59+
GoogleVertexAiChatCompletionTaskSettings taskSettings,
6160
@Nullable GoogleVertexAiSecretSettings secrets
6261
) {
6362
super(
@@ -73,6 +72,14 @@ public GoogleVertexAiChatCompletionModel(
7372
}
7473
}
7574

75+
private GoogleVertexAiChatCompletionModel(
76+
GoogleVertexAiChatCompletionModel model,
77+
GoogleVertexAiChatCompletionTaskSettings taskSettings
78+
) {
79+
super(model, taskSettings);
80+
streamingURI = model.streamingURI();
81+
}
82+
7683
public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, UnifiedCompletionRequest request) {
7784
var originalModelServiceSettings = model.getServiceSettings();
7885

@@ -93,6 +100,26 @@ public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionM
93100
);
94101
}
95102

103+
/**
104+
* Overrides the task settings in the given model with the settings in the map. If no new settings are present or the provided settings
105+
* do not differ from those already in the model, returns the original model
106+
* @param model the model whose task settings will be overridden
107+
* @param taskSettingsMap the new task settings to use
108+
* @return a {@link GoogleVertexAiChatCompletionModel} with overridden {@link GoogleVertexAiChatCompletionTaskSettings}
109+
*/
110+
public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettingsMap) {
111+
if (taskSettingsMap == null || taskSettingsMap.isEmpty()) {
112+
return model;
113+
}
114+
115+
var requestTaskSettings = GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettingsMap);
116+
if (requestTaskSettings.isEmpty() || model.getTaskSettings().equals(requestTaskSettings)) {
117+
return model;
118+
}
119+
var combinedTaskSettings = GoogleVertexAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings);
120+
return new GoogleVertexAiChatCompletionModel(model, combinedTaskSettings);
121+
}
122+
96123
@Override
97124
public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) {
98125
return visitor.create(this, taskSettings);
@@ -109,8 +136,8 @@ public GoogleVertexAiChatCompletionServiceSettings getServiceSettings() {
109136
}
110137

111138
@Override
112-
public EmptyTaskSettings getTaskSettings() {
113-
return (EmptyTaskSettings) super.getTaskSettings();
139+
public GoogleVertexAiChatCompletionTaskSettings getTaskSettings() {
140+
return (GoogleVertexAiChatCompletionTaskSettings) super.getTaskSettings();
114141
}
115142

116143
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.common.Strings;
1213
import org.elasticsearch.common.ValidationException;
1314
import org.elasticsearch.common.io.stream.StreamInput;
1415
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -116,6 +117,11 @@ public String getWriteableName() {
116117
return NAME;
117118
}
118119

120+
@Override
121+
public RateLimitSettings rateLimitSettings() {
122+
return rateLimitSettings;
123+
}
124+
119125
@Override
120126
public TransportVersion getMinimalSupportedVersion() {
121127
assert false : "should never be called when supportsVersion is used";
@@ -161,7 +167,7 @@ public int hashCode() {
161167
}
162168

163169
@Override
164-
public RateLimitSettings rateLimitSettings() {
165-
return rateLimitSettings;
170+
public String toString() {
171+
return Strings.toString(this);
166172
}
167173
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai.completion;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.common.Strings;
13+
import org.elasticsearch.common.ValidationException;
14+
import org.elasticsearch.common.io.stream.StreamInput;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.inference.TaskSettings;
17+
import org.elasticsearch.xcontent.XContentBuilder;
18+
19+
import java.io.IOException;
20+
import java.util.HashMap;
21+
import java.util.Map;
22+
import java.util.Objects;
23+
24+
public class GoogleVertexAiChatCompletionTaskSettings implements TaskSettings {
25+
public static final String NAME = "google_vertex_ai_chatcompletion_task_settings";
26+
27+
private final ThinkingConfig thinkingConfig;
28+
29+
public static final GoogleVertexAiChatCompletionTaskSettings EMPTY_SETTINGS = new GoogleVertexAiChatCompletionTaskSettings();
30+
private static final ThinkingConfig EMPTY_THINKING_CONFIG = new ThinkingConfig();
31+
32+
public GoogleVertexAiChatCompletionTaskSettings() {
33+
thinkingConfig = EMPTY_THINKING_CONFIG;
34+
}
35+
36+
public GoogleVertexAiChatCompletionTaskSettings(ThinkingConfig thinkingConfig) {
37+
this.thinkingConfig = Objects.requireNonNullElse(thinkingConfig, EMPTY_THINKING_CONFIG);
38+
}
39+
40+
public GoogleVertexAiChatCompletionTaskSettings(StreamInput in) throws IOException {
41+
thinkingConfig = new ThinkingConfig(in);
42+
}
43+
44+
public static GoogleVertexAiChatCompletionTaskSettings fromMap(Map<String, Object> taskSettings) {
45+
ValidationException validationException = new ValidationException();
46+
47+
// Extract optional thinkingConfig settings
48+
ThinkingConfig thinkingConfig = ThinkingConfig.fromMap(taskSettings, validationException);
49+
50+
if (validationException.validationErrors().isEmpty() == false) {
51+
throw validationException;
52+
}
53+
54+
return new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig);
55+
}
56+
57+
public static GoogleVertexAiChatCompletionTaskSettings of(
58+
GoogleVertexAiChatCompletionTaskSettings originalTaskSettings,
59+
GoogleVertexAiChatCompletionTaskSettings newTaskSettings
60+
) {
61+
ThinkingConfig thinkingConfig = newTaskSettings.thinkingConfig().isEmpty()
62+
? originalTaskSettings.thinkingConfig()
63+
: newTaskSettings.thinkingConfig();
64+
return new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig);
65+
}
66+
67+
public ThinkingConfig thinkingConfig() {
68+
return thinkingConfig;
69+
}
70+
71+
@Override
72+
public boolean isEmpty() {
73+
return thinkingConfig.isEmpty();
74+
}
75+
76+
@Override
77+
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
78+
GoogleVertexAiChatCompletionTaskSettings newTaskSettings = GoogleVertexAiChatCompletionTaskSettings.fromMap(
79+
new HashMap<>(newSettings)
80+
);
81+
return GoogleVertexAiChatCompletionTaskSettings.of(this, newTaskSettings);
82+
}
83+
84+
@Override
85+
public String getWriteableName() {
86+
return NAME;
87+
}
88+
89+
@Override
90+
public TransportVersion getMinimalSupportedVersion() {
91+
return TransportVersions.GEMINI_THINKING_BUDGET_ADDED;
92+
}
93+
94+
@Override
95+
public void writeTo(StreamOutput out) throws IOException {
96+
thinkingConfig.writeTo(out);
97+
}
98+
99+
@Override
100+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
101+
builder.startObject();
102+
thinkingConfig.toXContent(builder, params);
103+
builder.endObject();
104+
return builder;
105+
}
106+
107+
@Override
108+
public boolean equals(Object o) {
109+
if (o == null || getClass() != o.getClass()) return false;
110+
GoogleVertexAiChatCompletionTaskSettings that = (GoogleVertexAiChatCompletionTaskSettings) o;
111+
return Objects.equals(thinkingConfig, that.thinkingConfig);
112+
}
113+
114+
@Override
115+
public int hashCode() {
116+
return Objects.hashCode(thinkingConfig);
117+
}
118+
119+
@Override
120+
public String toString() {
121+
return Strings.toString(this);
122+
}
123+
}

0 commit comments

Comments
 (0)