Skip to content

Commit 546f5ad

Browse files
committed
Move thinking config settings into task settings
1 parent 4171c69 commit 546f5ad

File tree

11 files changed

+379
-102
lines changed

11 files changed

+379
-102
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
9090
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
9191
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionServiceSettings;
92+
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionTaskSettings;
9293
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
9394
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
9495
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
@@ -570,6 +571,14 @@ private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry
570571
)
571572
);
572573

574+
namedWriteables.add(
575+
new NamedWriteableRegistry.Entry(
576+
TaskSettings.class,
577+
GoogleVertexAiChatCompletionTaskSettings.NAME,
578+
GoogleVertexAiChatCompletionTaskSettings::new
579+
)
580+
);
581+
573582
}
574583

575584
private static void addInternalNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,14 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Obje
7373

7474
@Override
7575
public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings) {
76+
var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, taskSettings);
7677
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
7778

7879
var manager = new GenericRequestManager<>(
7980
serviceComponents.threadPool(),
80-
model,
81+
overriddenModel,
8182
CHAT_COMPLETION_HANDLER,
82-
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
83+
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), overriddenModel),
8384
ChatCompletionInput.class
8485
);
8586

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.apache.http.client.utils.URIBuilder;
1111
import org.elasticsearch.core.Nullable;
12-
import org.elasticsearch.inference.EmptyTaskSettings;
1312
import org.elasticsearch.inference.ModelConfigurations;
1413
import org.elasticsearch.inference.ModelSecrets;
1514
import org.elasticsearch.inference.TaskType;
@@ -47,7 +46,7 @@ public GoogleVertexAiChatCompletionModel(
4746
taskType,
4847
service,
4948
GoogleVertexAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
50-
new EmptyTaskSettings(),
49+
GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettings),
5150
GoogleVertexAiSecretSettings.fromMap(secrets)
5251
);
5352
}
@@ -57,7 +56,7 @@ public GoogleVertexAiChatCompletionModel(
5756
TaskType taskType,
5857
String service,
5958
GoogleVertexAiChatCompletionServiceSettings serviceSettings,
60-
EmptyTaskSettings taskSettings,
59+
GoogleVertexAiChatCompletionTaskSettings taskSettings,
6160
@Nullable GoogleVertexAiSecretSettings secrets
6261
) {
6362
super(
@@ -73,15 +72,22 @@ public GoogleVertexAiChatCompletionModel(
7372
}
7473
}
7574

75+
/**
 * Copy constructor used when only the task settings differ from an existing model.
 * The superclass receives the source model together with the replacement task settings,
 * and the streaming URI is carried over from the source model unchanged.
 *
 * @param model        the model to copy from
 * @param taskSettings the task settings the new instance should expose
 */
private GoogleVertexAiChatCompletionModel(GoogleVertexAiChatCompletionModel model, GoogleVertexAiChatCompletionTaskSettings taskSettings) {
    super(model, taskSettings);
    this.streamingURI = model.streamingURI();
}
82+
7683
public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, UnifiedCompletionRequest request) {
7784
var originalModelServiceSettings = model.getServiceSettings();
7885

7986
var newServiceSettings = new GoogleVertexAiChatCompletionServiceSettings(
8087
originalModelServiceSettings.projectId(),
8188
originalModelServiceSettings.location(),
8289
Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()),
83-
originalModelServiceSettings.rateLimitSettings(),
84-
originalModelServiceSettings.thinkingConfig()
90+
originalModelServiceSettings.rateLimitSettings()
8591
);
8692

8793
return new GoogleVertexAiChatCompletionModel(
@@ -94,6 +100,26 @@ public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionM
94100
);
95101
}
96102

103+
/**
104+
* Overrides the task settings in the given model with the settings in the map. If no new settings are present or the provided settings
105+
* do not differ from those already in the model, returns the original model
106+
* @param model the model whose task settings will be overridden
107+
* @param taskSettingsMap the new task settings to use
108+
* @return a {@link GoogleVertexAiChatCompletionModel} with overridden {@link GoogleVertexAiChatCompletionTaskSettings}
109+
*/
110+
public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettingsMap) {
111+
if (taskSettingsMap == null || taskSettingsMap.isEmpty()) {
112+
return model;
113+
}
114+
115+
var requestTaskSettings = GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettingsMap);
116+
if (requestTaskSettings.isEmpty() || model.getTaskSettings().equals(requestTaskSettings)) {
117+
return model;
118+
}
119+
var combinedTaskSettings = GoogleVertexAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings);
120+
return new GoogleVertexAiChatCompletionModel(model, combinedTaskSettings);
121+
}
122+
97123
@Override
98124
public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) {
99125
return visitor.create(this, taskSettings);
@@ -110,8 +136,8 @@ public GoogleVertexAiChatCompletionServiceSettings getServiceSettings() {
110136
}
111137

112138
@Override
113-
public EmptyTaskSettings getTaskSettings() {
114-
return (EmptyTaskSettings) super.getTaskSettings();
139+
public GoogleVertexAiChatCompletionTaskSettings getTaskSettings() {
140+
return (GoogleVertexAiChatCompletionTaskSettings) super.getTaskSettings();
115141
}
116142

117143
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,12 @@ public class GoogleVertexAiChatCompletionServiceSettings extends FilteredXConten
4545
private final String projectId;
4646

4747
private final RateLimitSettings rateLimitSettings;
48-
private final ThinkingConfig thinkingConfig;
4948

5049
// https://cloud.google.com/vertex-ai/docs/quotas#eval-quotas
5150
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1000);
52-
private static final ThinkingConfig EMPTY_THINKING_CONFIG = new ThinkingConfig();
5351

5452
public GoogleVertexAiChatCompletionServiceSettings(StreamInput in) throws IOException {
55-
this.projectId = in.readString();
56-
this.location = in.readString();
57-
this.modelId = in.readString();
58-
this.rateLimitSettings = new RateLimitSettings(in);
59-
60-
if (in.getTransportVersion().onOrAfter(TransportVersions.GEMINI_THINKING_BUDGET_ADDED)) {
61-
thinkingConfig = new ThinkingConfig(in);
62-
} else {
63-
thinkingConfig = EMPTY_THINKING_CONFIG;
64-
}
53+
this(in.readString(), in.readString(), in.readString(), new RateLimitSettings(in));
6554
}
6655

6756
@Override
@@ -70,7 +59,6 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil
7059
builder.field(LOCATION, location);
7160
builder.field(MODEL_ID, modelId);
7261
rateLimitSettings.toXContent(builder, params);
73-
thinkingConfig.toXContent(builder, params);
7462
return builder;
7563
}
7664

@@ -91,28 +79,23 @@ public static GoogleVertexAiChatCompletionServiceSettings fromMap(Map<String, Ob
9179
context
9280
);
9381

94-
// Extract optional thinkingConfig settings
95-
ThinkingConfig thinkingConfig = ThinkingConfig.of(map, validationException, GoogleVertexAiService.NAME, context);
96-
9782
if (validationException.validationErrors().isEmpty() == false) {
9883
throw validationException;
9984
}
10085

101-
return new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings, thinkingConfig);
86+
return new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings);
10287
}
10388

10489
public GoogleVertexAiChatCompletionServiceSettings(
10590
String projectId,
10691
String location,
10792
String modelId,
108-
@Nullable RateLimitSettings rateLimitSettings,
109-
@Nullable ThinkingConfig thinkingConfig
93+
@Nullable RateLimitSettings rateLimitSettings
11094
) {
11195
this.projectId = projectId;
11296
this.location = location;
11397
this.modelId = modelId;
11498
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
115-
this.thinkingConfig = Objects.requireNonNullElse(thinkingConfig, EMPTY_THINKING_CONFIG);
11699
}
117100

118101
public String location() {
@@ -139,10 +122,6 @@ public RateLimitSettings rateLimitSettings() {
139122
return rateLimitSettings;
140123
}
141124

142-
public ThinkingConfig thinkingConfig() {
143-
return thinkingConfig;
144-
}
145-
146125
@Override
147126
public TransportVersion getMinimalSupportedVersion() {
148127
assert false : "should never be called when supportsVersion is used";
@@ -161,9 +140,6 @@ public void writeTo(StreamOutput out) throws IOException {
161140
out.writeString(location);
162141
out.writeString(modelId);
163142
rateLimitSettings.writeTo(out);
164-
if (out.getTransportVersion().onOrAfter(TransportVersions.GEMINI_THINKING_BUDGET_ADDED)) {
165-
thinkingConfig.writeTo(out);
166-
}
167143
}
168144

169145
@Override
@@ -182,13 +158,12 @@ public boolean equals(Object o) {
182158
return Objects.equals(location, that.location)
183159
&& Objects.equals(modelId, that.modelId)
184160
&& Objects.equals(projectId, that.projectId)
185-
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
186-
&& Objects.equals(thinkingConfig, that.thinkingConfig);
161+
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
187162
}
188163

189164
@Override
190165
public int hashCode() {
191-
return Objects.hash(location, modelId, projectId, rateLimitSettings, thinkingConfig);
166+
return Objects.hash(location, modelId, projectId, rateLimitSettings);
192167
}
193168

194169
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai.completion;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.common.Strings;
13+
import org.elasticsearch.common.ValidationException;
14+
import org.elasticsearch.common.io.stream.StreamInput;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.inference.TaskSettings;
17+
import org.elasticsearch.xcontent.XContentBuilder;
18+
19+
import java.io.IOException;
20+
import java.util.HashMap;
21+
import java.util.Map;
22+
import java.util.Objects;
23+
24+
public class GoogleVertexAiChatCompletionTaskSettings implements TaskSettings {
25+
public static final String NAME = "google_vertex_ai_chatcompletion_task_settings";
26+
27+
private final ThinkingConfig thinkingConfig;
28+
29+
public static final GoogleVertexAiChatCompletionTaskSettings EMPTY_SETTINGS = new GoogleVertexAiChatCompletionTaskSettings();
30+
private static final ThinkingConfig EMPTY_THINKING_CONFIG = new ThinkingConfig();
31+
32+
public GoogleVertexAiChatCompletionTaskSettings() {
33+
thinkingConfig = EMPTY_THINKING_CONFIG;
34+
}
35+
36+
public GoogleVertexAiChatCompletionTaskSettings(ThinkingConfig thinkingConfig) {
37+
this.thinkingConfig = Objects.requireNonNullElse(thinkingConfig, EMPTY_THINKING_CONFIG);
38+
}
39+
40+
public GoogleVertexAiChatCompletionTaskSettings(StreamInput in) throws IOException {
41+
if (in.getTransportVersion().onOrAfter(TransportVersions.GEMINI_THINKING_BUDGET_ADDED)) {
42+
thinkingConfig = new ThinkingConfig(in);
43+
} else {
44+
thinkingConfig = EMPTY_THINKING_CONFIG;
45+
}
46+
}
47+
48+
public static GoogleVertexAiChatCompletionTaskSettings fromMap(Map<String, Object> taskSettings) {
49+
ValidationException validationException = new ValidationException();
50+
51+
// Extract optional thinkingConfig settings
52+
ThinkingConfig thinkingConfig = ThinkingConfig.fromMap(taskSettings, validationException);
53+
54+
if (validationException.validationErrors().isEmpty() == false) {
55+
throw validationException;
56+
}
57+
58+
return new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig);
59+
}
60+
61+
public static GoogleVertexAiChatCompletionTaskSettings of(
62+
GoogleVertexAiChatCompletionTaskSettings originalTaskSettings,
63+
GoogleVertexAiChatCompletionTaskSettings newTaskSettings
64+
) {
65+
ThinkingConfig thinkingConfig = newTaskSettings.thinkingConfig().isEmpty()
66+
? originalTaskSettings.thinkingConfig()
67+
: newTaskSettings.thinkingConfig();
68+
return new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig);
69+
}
70+
71+
public ThinkingConfig thinkingConfig() {
72+
return thinkingConfig;
73+
}
74+
75+
@Override
76+
public boolean isEmpty() {
77+
return thinkingConfig.isEmpty();
78+
}
79+
80+
@Override
81+
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
82+
GoogleVertexAiChatCompletionTaskSettings newTaskSettings = GoogleVertexAiChatCompletionTaskSettings.fromMap(
83+
new HashMap<>(newSettings)
84+
);
85+
return GoogleVertexAiChatCompletionTaskSettings.of(this, newTaskSettings);
86+
}
87+
88+
@Override
89+
public String getWriteableName() {
90+
return NAME;
91+
}
92+
93+
@Override
94+
public TransportVersion getMinimalSupportedVersion() {
95+
return TransportVersions.GEMINI_THINKING_BUDGET_ADDED;
96+
}
97+
98+
@Override
99+
public void writeTo(StreamOutput out) throws IOException {
100+
if (out.getTransportVersion().onOrAfter(TransportVersions.GEMINI_THINKING_BUDGET_ADDED)) {
101+
thinkingConfig.writeTo(out);
102+
}
103+
}
104+
105+
@Override
106+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
107+
builder.startObject();
108+
thinkingConfig.toXContent(builder, params);
109+
builder.endObject();
110+
return builder;
111+
}
112+
113+
@Override
114+
public boolean equals(Object o) {
115+
if (o == null || getClass() != o.getClass()) return false;
116+
GoogleVertexAiChatCompletionTaskSettings that = (GoogleVertexAiChatCompletionTaskSettings) o;
117+
return Objects.equals(thinkingConfig, that.thinkingConfig);
118+
}
119+
120+
@Override
121+
public int hashCode() {
122+
return Objects.hashCode(thinkingConfig);
123+
}
124+
125+
@Override
126+
public String toString() {
127+
return Strings.toString(this);
128+
}
129+
}

0 commit comments

Comments
 (0)