Skip to content

Commit 7039a1d

Browse files
authored
Adds support for input_type field to Vertex inference service (#116431)
* Adding input type to google vertex ai service * Update docs/changelog/116431.yaml * PR feedback - backwards compatibility * Fix lint error
1 parent a71c132 commit 7039a1d

File tree

18 files changed

+697
-106
lines changed

18 files changed

+697
-106
lines changed

docs/changelog/116431.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 116431
2+
summary: Adds support for `input_type` field to Vertex inference service
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ static TransportVersion def(int id) {
193193
public static final TransportVersion ROLE_MONITOR_STATS = def(8_787_00_0);
194194
public static final TransportVersion DATA_STREAM_INDEX_VERSION_DEPRECATION_CHECK = def(8_788_00_0);
195195
public static final TransportVersion ADD_COMPATIBILITY_VERSIONS_TO_NODE_INFO = def(8_789_00_0);
196+
public static final TransportVersion VERTEX_AI_INPUT_TYPE_ADDED = def(8_790_00_0);
196197

197198
/*
198199
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionCreator.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.external.action.googlevertexai;
99

10+
import org.elasticsearch.inference.InputType;
1011
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1112
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1213
import org.elasticsearch.xpack.inference.external.http.sender.GoogleVertexAiEmbeddingsRequestManager;
@@ -33,9 +34,10 @@ public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceCompo
3334
}
3435

3536
@Override
36-
public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings) {
37+
public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
38+
var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, taskSettings, inputType);
3739
var requestManager = new GoogleVertexAiEmbeddingsRequestManager(
38-
model,
40+
overriddenModel,
3941
serviceComponents.truncator(),
4042
serviceComponents.threadPool()
4143
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionVisitor.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.external.action.googlevertexai;
99

10+
import org.elasticsearch.inference.InputType;
1011
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1112
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
1213
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
@@ -15,7 +16,7 @@
1516

1617
public interface GoogleVertexAiActionVisitor {
1718

18-
ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings);
19+
ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
1920

2021
ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings);
2122
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public HttpRequest createHttpRequest() {
4040
HttpPost httpPost = new HttpPost(model.uri());
4141

4242
ByteArrayEntity byteEntity = new ByteArrayEntity(
43-
Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), model.getTaskSettings().autoTruncate()))
43+
Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), model.getTaskSettings()))
4444
.getBytes(StandardCharsets.UTF_8)
4545
);
4646

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestEntity.java

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,35 @@
77

88
package org.elasticsearch.xpack.inference.external.request.googlevertexai;
99

10-
import org.elasticsearch.core.Nullable;
10+
import org.elasticsearch.inference.InputType;
1111
import org.elasticsearch.xcontent.ToXContentObject;
1212
import org.elasticsearch.xcontent.XContentBuilder;
13+
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
1314

1415
import java.io.IOException;
1516
import java.util.List;
1617
import java.util.Objects;
1718

18-
public record GoogleVertexAiEmbeddingsRequestEntity(List<String> inputs, @Nullable Boolean autoTruncation) implements ToXContentObject {
19+
import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage;
20+
21+
public record GoogleVertexAiEmbeddingsRequestEntity(List<String> inputs, GoogleVertexAiEmbeddingsTaskSettings taskSettings)
22+
implements
23+
ToXContentObject {
1924

2025
private static final String INSTANCES_FIELD = "instances";
2126
private static final String CONTENT_FIELD = "content";
2227
private static final String PARAMETERS_FIELD = "parameters";
2328
private static final String AUTO_TRUNCATE_FIELD = "autoTruncate";
29+
private static final String TASK_TYPE_FIELD = "task_type";
30+
31+
private static final String CLASSIFICATION_TASK_TYPE = "CLASSIFICATION";
32+
private static final String CLUSTERING_TASK_TYPE = "CLUSTERING";
33+
private static final String RETRIEVAL_DOCUMENT_TASK_TYPE = "RETRIEVAL_DOCUMENT";
34+
private static final String RETRIEVAL_QUERY_TASK_TYPE = "RETRIEVAL_QUERY";
2435

2536
public GoogleVertexAiEmbeddingsRequestEntity {
2637
Objects.requireNonNull(inputs);
38+
Objects.requireNonNull(taskSettings);
2739
}
2840

2941
@Override
@@ -35,21 +47,38 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
3547
builder.startObject();
3648
{
3749
builder.field(CONTENT_FIELD, input);
50+
51+
if (taskSettings.getInputType() != null) {
52+
builder.field(TASK_TYPE_FIELD, convertToString(taskSettings.getInputType()));
53+
}
3854
}
3955
builder.endObject();
4056
}
4157

4258
builder.endArray();
4359

44-
if (autoTruncation != null) {
60+
if (taskSettings.autoTruncate() != null) {
4561
builder.startObject(PARAMETERS_FIELD);
4662
{
47-
builder.field(AUTO_TRUNCATE_FIELD, autoTruncation);
63+
builder.field(AUTO_TRUNCATE_FIELD, taskSettings.autoTruncate());
4864
}
4965
builder.endObject();
5066
}
5167
builder.endObject();
5268

5369
return builder;
5470
}
71+
72+
static String convertToString(InputType inputType) {
73+
return switch (inputType) {
74+
case INGEST -> RETRIEVAL_DOCUMENT_TASK_TYPE;
75+
case SEARCH -> RETRIEVAL_QUERY_TASK_TYPE;
76+
case CLASSIFICATION -> CLASSIFICATION_TASK_TYPE;
77+
case CLUSTERING -> CLUSTERING_TASK_TYPE;
78+
default -> {
79+
assert false : invalidInputTypeMessage(inputType);
80+
yield null;
81+
}
82+
};
83+
}
5584
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,25 @@
77

88
package org.elasticsearch.xpack.inference.services.googlevertexai;
99

10+
import org.elasticsearch.inference.InputType;
1011
import org.elasticsearch.inference.Model;
1112
import org.elasticsearch.inference.ModelConfigurations;
1213
import org.elasticsearch.inference.ModelSecrets;
1314
import org.elasticsearch.inference.ServiceSettings;
15+
import org.elasticsearch.inference.TaskSettings;
1416
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1517
import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionVisitor;
1618

19+
import java.net.URI;
1720
import java.util.Map;
1821
import java.util.Objects;
1922

2023
public abstract class GoogleVertexAiModel extends Model {
2124

2225
private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;
2326

27+
protected URI uri;
28+
2429
public GoogleVertexAiModel(
2530
ModelConfigurations configurations,
2631
ModelSecrets secrets,
@@ -34,13 +39,24 @@ public GoogleVertexAiModel(
3439
public GoogleVertexAiModel(GoogleVertexAiModel model, ServiceSettings serviceSettings) {
3540
super(model, serviceSettings);
3641

42+
uri = model.uri();
43+
rateLimitServiceSettings = model.rateLimitServiceSettings();
44+
}
45+
46+
public GoogleVertexAiModel(GoogleVertexAiModel model, TaskSettings taskSettings) {
47+
super(model, taskSettings);
48+
49+
uri = model.uri();
3750
rateLimitServiceSettings = model.rateLimitServiceSettings();
3851
}
3952

40-
public abstract ExecutableAction accept(GoogleVertexAiActionVisitor creator, Map<String, Object> taskSettings);
53+
public abstract ExecutableAction accept(GoogleVertexAiActionVisitor creator, Map<String, Object> taskSettings, InputType inputType);
4154

4255
public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
4356
return rateLimitServiceSettings;
4457
}
4558

59+
public URI uri() {
60+
return uri;
61+
}
4662
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ protected void doInfer(
210210

211211
var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());
212212

213-
var action = googleVertexAiModel.accept(actionCreator, taskSettings);
213+
var action = googleVertexAiModel.accept(actionCreator, taskSettings, inputType);
214214
action.execute(inputs, timeout, listener);
215215
}
216216

@@ -235,7 +235,7 @@ protected void doChunkedInfer(
235235
).batchRequestsWithListeners(listener);
236236

237237
for (var request : batchedRequests) {
238-
var action = googleVertexAiModel.accept(actionCreator, taskSettings);
238+
var action = googleVertexAiModel.accept(actionCreator, taskSettings, inputType);
239239
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
240240
}
241241
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
import org.elasticsearch.common.util.LazyInitializable;
1212
import org.elasticsearch.core.Nullable;
1313
import org.elasticsearch.inference.ChunkingSettings;
14+
import org.elasticsearch.inference.InputType;
1415
import org.elasticsearch.inference.ModelConfigurations;
1516
import org.elasticsearch.inference.ModelSecrets;
1617
import org.elasticsearch.inference.SettingsConfiguration;
1718
import org.elasticsearch.inference.TaskType;
1819
import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType;
1920
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
21+
import org.elasticsearch.inference.configuration.SettingsConfigurationSelectOption;
2022
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
2123
import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionVisitor;
2224
import org.elasticsearch.xpack.inference.external.request.googlevertexai.GoogleVertexAiUtils;
@@ -29,13 +31,25 @@
2931
import java.util.Collections;
3032
import java.util.HashMap;
3133
import java.util.Map;
34+
import java.util.stream.Stream;
3235

3336
import static org.elasticsearch.core.Strings.format;
3437
import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE;
38+
import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE;
3539

3640
public class GoogleVertexAiEmbeddingsModel extends GoogleVertexAiModel {
3741

38-
private URI uri;
42+
public static GoogleVertexAiEmbeddingsModel of(
43+
GoogleVertexAiEmbeddingsModel model,
44+
Map<String, Object> taskSettings,
45+
InputType inputType
46+
) {
47+
var requestTaskSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(taskSettings);
48+
return new GoogleVertexAiEmbeddingsModel(
49+
model,
50+
GoogleVertexAiEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings, inputType)
51+
);
52+
}
3953

4054
public GoogleVertexAiEmbeddingsModel(
4155
String inferenceEntityId,
@@ -62,6 +76,10 @@ public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, Google
6276
super(model, serviceSettings);
6377
}
6478

79+
public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, GoogleVertexAiEmbeddingsTaskSettings taskSettings) {
80+
super(model, taskSettings);
81+
}
82+
6583
// Should only be used directly for testing
6684
GoogleVertexAiEmbeddingsModel(
6785
String inferenceEntityId,
@@ -126,13 +144,9 @@ public GoogleVertexAiEmbeddingsRateLimitServiceSettings rateLimitServiceSettings
126144
return (GoogleVertexAiEmbeddingsRateLimitServiceSettings) super.rateLimitServiceSettings();
127145
}
128146

129-
public URI uri() {
130-
return uri;
131-
}
132-
133147
@Override
134-
public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) {
135-
return visitor.create(this, taskSettings);
148+
public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
149+
return visitor.create(this, taskSettings, inputType);
136150
}
137151

138152
public static URI buildUri(String location, String projectId, String modelId) throws URISyntaxException {
@@ -161,11 +175,32 @@ public static Map<String, SettingsConfiguration> get() {
161175
new LazyInitializable<>(() -> {
162176
var configurationMap = new HashMap<String, SettingsConfiguration>();
163177

178+
configurationMap.put(
179+
INPUT_TYPE,
180+
new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.DROPDOWN)
181+
.setLabel("Input Type")
182+
.setOrder(1)
183+
.setRequired(false)
184+
.setSensitive(false)
185+
.setTooltip("Specifies the type of input passed to the model.")
186+
.setType(SettingsConfigurationFieldType.STRING)
187+
.setOptions(
188+
Stream.of(
189+
InputType.CLASSIFICATION.toString(),
190+
InputType.CLUSTERING.toString(),
191+
InputType.INGEST.toString(),
192+
InputType.SEARCH.toString()
193+
).map(v -> new SettingsConfigurationSelectOption.Builder().setLabelAndValue(v).build()).toList()
194+
)
195+
.setValue("")
196+
.build()
197+
);
198+
164199
configurationMap.put(
165200
AUTO_TRUNCATE,
166201
new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.TOGGLE)
167202
.setLabel("Auto Truncate")
168-
.setOrder(1)
203+
.setOrder(2)
169204
.setRequired(false)
170205
.setSensitive(false)
171206
.setTooltip("Specifies if the API truncates inputs longer than the maximum token length automatically.")

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsRequestTaskSettings.java

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,46 @@
99

1010
import org.elasticsearch.common.ValidationException;
1111
import org.elasticsearch.core.Nullable;
12+
import org.elasticsearch.inference.InputType;
13+
import org.elasticsearch.inference.ModelConfigurations;
1214

1315
import java.util.Map;
1416

1517
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
18+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
19+
import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE;
20+
import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.VALID_REQUEST_VALUES;
1621

17-
public record GoogleVertexAiEmbeddingsRequestTaskSettings(@Nullable Boolean autoTruncate) {
22+
public record GoogleVertexAiEmbeddingsRequestTaskSettings(@Nullable Boolean autoTruncate, @Nullable InputType inputType) {
1823

19-
public static final GoogleVertexAiEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new GoogleVertexAiEmbeddingsRequestTaskSettings(null);
24+
public static final GoogleVertexAiEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new GoogleVertexAiEmbeddingsRequestTaskSettings(
25+
null,
26+
null
27+
);
2028

2129
public static GoogleVertexAiEmbeddingsRequestTaskSettings fromMap(Map<String, Object> map) {
22-
if (map.isEmpty()) {
23-
return GoogleVertexAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS;
30+
if (map == null || map.isEmpty()) {
31+
return EMPTY_SETTINGS;
2432
}
2533

2634
ValidationException validationException = new ValidationException();
2735

36+
InputType inputType = extractOptionalEnum(
37+
map,
38+
INPUT_TYPE,
39+
ModelConfigurations.TASK_SETTINGS,
40+
InputType::fromString,
41+
VALID_REQUEST_VALUES,
42+
validationException
43+
);
44+
2845
Boolean autoTruncate = extractOptionalBoolean(map, GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, validationException);
2946

3047
if (validationException.validationErrors().isEmpty() == false) {
3148
throw validationException;
3249
}
3350

34-
return new GoogleVertexAiEmbeddingsRequestTaskSettings(autoTruncate);
51+
return new GoogleVertexAiEmbeddingsRequestTaskSettings(autoTruncate, inputType);
3552
}
3653

3754
}

0 commit comments

Comments
 (0)