Skip to content

Commit dd5bb0e

Browse files
committed
[ML] Bedrock Cohere Task Settings Support
Add support for Cohere Task Settings, InputType and Truncate, through the Amazon Bedrock provider integration. Task Settings can now be passed both during Inference endpoint creation and in Inference POST requests. Close #126156
1 parent 6e4cb81 commit dd5bb0e

File tree

15 files changed

+522
-102
lines changed

15 files changed

+522
-102
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ static TransportVersion def(int id) {
157157
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
158158
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
159159
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
160+
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
160161
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
161162
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
162163
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -214,6 +215,7 @@ static TransportVersion def(int id) {
214215
public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00);
215216
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
216217
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
218+
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS = def(9_048_00_0);
217219

218220
/*
219221
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
4242
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
4343
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
44+
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
4445
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings;
4546
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings;
4647
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings;
@@ -173,8 +174,13 @@ private static void addAmazonBedrockNamedWriteables(List<NamedWriteableRegistry.
173174
AmazonBedrockEmbeddingsServiceSettings::new
174175
)
175176
);
176-
177-
// no task settings for Amazon Bedrock Embeddings
177+
namedWriteables.add(
178+
new NamedWriteableRegistry.Entry(
179+
TaskSettings.class,
180+
AmazonBedrockEmbeddingsTaskSettings.NAME,
181+
AmazonBedrockEmbeddingsTaskSettings::new
182+
)
183+
);
178184

179185
namedWriteables.add(
180186
new NamedWriteableRegistry.Entry(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,31 @@
1111
import org.elasticsearch.inference.InputType;
1212
import org.elasticsearch.xcontent.ToXContentObject;
1313
import org.elasticsearch.xcontent.XContentBuilder;
14+
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
1415

1516
import java.io.IOException;
1617
import java.util.List;
1718
import java.util.Objects;
1819

1920
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
2021

21-
public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nullable InputType inputType) implements ToXContentObject {
22+
public record AmazonBedrockCohereEmbeddingsRequestEntity(
23+
List<String> input,
24+
@Nullable InputType inputType,
25+
AmazonBedrockEmbeddingsTaskSettings taskSettings
26+
) implements ToXContentObject {
2227

2328
private static final String TEXTS_FIELD = "texts";
2429
private static final String INPUT_TYPE_FIELD = "input_type";
2530
private static final String SEARCH_DOCUMENT = "search_document";
2631
private static final String SEARCH_QUERY = "search_query";
2732
private static final String CLUSTERING = "clustering";
2833
private static final String CLASSIFICATION = "classification";
34+
private static final String TRUNCATE = "truncate";
2935

3036
/**
 * Compact constructor: rejects a {@code null} input list or {@code null} task settings.
 * {@code inputType} is not checked here; it is declared {@code @Nullable}.
 *
 * @throws NullPointerException if {@code input} or {@code taskSettings} is {@code null};
 *                              the message names the offending component
 */
public AmazonBedrockCohereEmbeddingsRequestEntity {
    Objects.requireNonNull(input, "input");
    Objects.requireNonNull(taskSettings, "taskSettings");
}
3340

3441
@Override
@@ -38,11 +45,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
3845

3946
if (InputType.isSpecified(inputType)) {
4047
builder.field(INPUT_TYPE_FIELD, convertToString(inputType));
48+
} else if (InputType.isSpecified(taskSettings.inputType())) {
49+
builder.field(INPUT_TYPE_FIELD, convertToString(taskSettings.inputType()));
4150
} else {
4251
// input_type is required so default to document
4352
builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT);
4453
}
4554

55+
if (taskSettings.cohereTruncation() != null) {
56+
builder.field(TRUNCATE, taskSettings.cohereTruncation());
57+
}
58+
4659
builder.endObject();
4760
return builder;
4861
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public static ToXContent createEntity(
3939
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
4040
}
4141
case COHERE -> {
42-
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType);
42+
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings());
4343
}
4444
default -> {
4545
return null;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ protected void executeRequest(AmazonBedrockBaseClient client) {
7676

7777
@Override
7878
public Request truncate() {
79+
if (provider == AmazonBedrockProvider.COHERE) {
80+
return this; // Cohere has its own truncation logic
81+
}
7982
var truncatedInput = truncator.truncate(truncationResult.input());
8083
return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout);
8184
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ public class AmazonBedrockConstants {
1919
public static final String TOP_K_FIELD = "top_k";
2020
public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens";
2121

22+
public static final String INPUT_TYPE_FIELD = "input_type";
23+
public static final String TRUNCATE_FIELD = "truncate";
24+
2225
public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0;
2326
public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0;
2427

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ private static AmazonBedrockModel createModel(
304304
context
305305
);
306306
checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider());
307+
checkTaskSettingsForTextEmbeddingModel(model);
307308
return model;
308309
}
309310
case COMPLETION -> {
@@ -382,6 +383,17 @@ private static void checkProviderForTask(TaskType taskType, AmazonBedrockProvide
382383
}
383384
}
384385

386+
private static void checkTaskSettingsForTextEmbeddingModel(AmazonBedrockEmbeddingsModel model) {
387+
if (model.provider() != AmazonBedrockProvider.COHERE && model.getTaskSettings().cohereTruncation() != null) {
388+
throw new ElasticsearchStatusException(
389+
"The [%s] task type for provider [%s] does not allow [truncate] field",
390+
RestStatus.BAD_REQUEST,
391+
TaskType.TEXT_EMBEDDING,
392+
model.provider()
393+
);
394+
}
395+
}
396+
385397
private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) {
386398
var taskSettings = model.getTaskSettings();
387399
if (taskSettings.topK() != null) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,11 @@
77

88
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
99

10-
import org.elasticsearch.common.ValidationException;
1110
import org.elasticsearch.inference.ChunkingSettings;
12-
import org.elasticsearch.inference.EmptyTaskSettings;
1311
import org.elasticsearch.inference.Model;
1412
import org.elasticsearch.inference.ModelConfigurations;
1513
import org.elasticsearch.inference.ModelSecrets;
1614
import org.elasticsearch.inference.ServiceSettings;
17-
import org.elasticsearch.inference.TaskSettings;
1815
import org.elasticsearch.inference.TaskType;
1916
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
2017
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
@@ -28,10 +25,8 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
2825

2926
public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
3027
if (taskSettings != null && taskSettings.isEmpty() == false) {
31-
// no task settings allowed
32-
var validationException = new ValidationException();
33-
validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings");
34-
throw validationException;
28+
var updatedTaskSettings = embeddingsModel.getTaskSettings().updatedTaskSettings(taskSettings);
29+
return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedTaskSettings);
3530
}
3631

3732
return embeddingsModel;
@@ -52,7 +47,7 @@ public AmazonBedrockEmbeddingsModel(
5247
taskType,
5348
service,
5449
AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context),
55-
new EmptyTaskSettings(),
50+
AmazonBedrockEmbeddingsTaskSettings.fromMap(taskSettings),
5651
chunkingSettings,
5752
AwsSecretSettings.fromMap(secretSettings)
5853
);
@@ -63,12 +58,12 @@ public AmazonBedrockEmbeddingsModel(
6358
TaskType taskType,
6459
String service,
6560
AmazonBedrockEmbeddingsServiceSettings serviceSettings,
66-
TaskSettings taskSettings,
61+
AmazonBedrockEmbeddingsTaskSettings taskSettings,
6762
ChunkingSettings chunkingSettings,
6863
AwsSecretSettings secrets
6964
) {
7065
super(
71-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings),
66+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
7267
new ModelSecrets(secrets)
7368
);
7469
}
@@ -77,6 +72,10 @@ public AmazonBedrockEmbeddingsModel(Model model, ServiceSettings serviceSettings
7772
super(model, serviceSettings);
7873
}
7974

75+
/**
 * Creates an embeddings model from an existing model and a set of task settings by
 * delegating to the superclass constructor.
 *
 * @param model        the existing model passed through to the superclass
 * @param taskSettings the task settings passed through to the superclass
 */
public AmazonBedrockEmbeddingsModel(Model model, AmazonBedrockEmbeddingsTaskSettings taskSettings) {
    super(model, taskSettings);
}
78+
8079
@Override
8180
public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, Object> taskSettings) {
8281
return creator.create(this, taskSettings);
@@ -86,4 +85,9 @@ public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, O
8685
public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() {
8786
return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings();
8887
}
88+
89+
@Override
90+
public AmazonBedrockEmbeddingsTaskSettings getTaskSettings() {
91+
return (AmazonBedrockEmbeddingsTaskSettings) super.getTaskSettings();
92+
}
8993
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.common.ValidationException;
13+
import org.elasticsearch.common.io.stream.StreamInput;
14+
import org.elasticsearch.common.io.stream.StreamOutput;
15+
import org.elasticsearch.core.Nullable;
16+
import org.elasticsearch.inference.InputType;
17+
import org.elasticsearch.inference.ModelConfigurations;
18+
import org.elasticsearch.inference.TaskSettings;
19+
import org.elasticsearch.xcontent.XContentBuilder;
20+
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
21+
22+
import java.io.IOException;
23+
import java.util.EnumSet;
24+
import java.util.HashMap;
25+
import java.util.Map;
26+
27+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
28+
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.INPUT_TYPE_FIELD;
29+
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
30+
31+
public record AmazonBedrockEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable CohereTruncation cohereTruncation)
32+
implements
33+
TaskSettings {
34+
public static final String NAME = "amazon_bedrock_embeddings_task_settings";
35+
private static final AmazonBedrockEmbeddingsTaskSettings EMPTY = new AmazonBedrockEmbeddingsTaskSettings(null, null);
36+
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
37+
InputType.INGEST,
38+
InputType.INTERNAL_INGEST,
39+
InputType.SEARCH,
40+
InputType.INTERNAL_SEARCH,
41+
InputType.CLASSIFICATION,
42+
InputType.CLUSTERING
43+
);
44+
45+
public static AmazonBedrockEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
46+
if (map == null || map.isEmpty()) {
47+
return EMPTY;
48+
}
49+
50+
ValidationException validationException = new ValidationException();
51+
52+
var inputType = extractOptionalEnum(
53+
map,
54+
INPUT_TYPE_FIELD,
55+
ModelConfigurations.TASK_SETTINGS,
56+
InputType::fromString,
57+
VALID_INPUT_TYPE_VALUES,
58+
validationException
59+
);
60+
61+
var cohereTruncation = extractOptionalEnum(
62+
map,
63+
TRUNCATE_FIELD,
64+
ModelConfigurations.TASK_SETTINGS,
65+
CohereTruncation::fromString,
66+
CohereTruncation.ALL,
67+
validationException
68+
);
69+
70+
if (validationException.validationErrors().isEmpty() == false) {
71+
throw validationException;
72+
}
73+
74+
return new AmazonBedrockEmbeddingsTaskSettings(inputType, cohereTruncation);
75+
}
76+
77+
public AmazonBedrockEmbeddingsTaskSettings(StreamInput in) throws IOException {
78+
this(in.readOptionalEnum(InputType.class), in.readOptionalEnum(CohereTruncation.class));
79+
}
80+
81+
@Override
82+
public boolean isEmpty() {
83+
return inputType() == null && cohereTruncation() == null;
84+
}
85+
86+
@Override
87+
public AmazonBedrockEmbeddingsTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
88+
var newTaskSettings = fromMap(new HashMap<>(newSettings));
89+
90+
return new AmazonBedrockEmbeddingsTaskSettings(
91+
firstNonNullOrNull(newTaskSettings.inputType(), inputType()),
92+
firstNonNullOrNull(newTaskSettings.cohereTruncation(), cohereTruncation())
93+
);
94+
}
95+
96+
private static <T> T firstNonNullOrNull(T first, T second) {
97+
return first != null ? first : second;
98+
}
99+
100+
@Override
101+
public String getWriteableName() {
102+
return NAME;
103+
}
104+
105+
@Override
106+
public TransportVersion getMinimalSupportedVersion() {
107+
return TransportVersions.AMAZON_BEDROCK_TASK_SETTINGS;
108+
}
109+
110+
@Override
111+
public void writeTo(StreamOutput out) throws IOException {
112+
out.writeOptionalEnum(inputType());
113+
out.writeOptionalEnum(cohereTruncation());
114+
}
115+
116+
@Override
117+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
118+
builder.startObject();
119+
if (inputType != null) {
120+
builder.field(INPUT_TYPE_FIELD, inputType);
121+
}
122+
if (cohereTruncation != null) {
123+
builder.field(TRUNCATE_FIELD, cohereTruncation);
124+
}
125+
return builder.endObject();
126+
}
127+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.services.cohere;
99

10+
import java.util.EnumSet;
1011
import java.util.Locale;
1112

1213
/**
@@ -31,6 +32,8 @@ public enum CohereTruncation {
3132
*/
3233
END;
3334

35+
public static final EnumSet<CohereTruncation> ALL = EnumSet.allOf(CohereTruncation.class);
36+
3437
@Override
3538
public String toString() {
3639
return name().toLowerCase(Locale.ROOT);

0 commit comments

Comments
 (0)