Skip to content

Commit 85b507d

Browse files
authored
[ML] Bedrock Cohere Task Settings Support (#126493) (#126559)
Add support for Cohere Task Settings and Truncate through the Amazon Bedrock provider integration. Task Settings can now be passed both during Inference endpoint creation and in Inference POST requests. Relates to #126156
1 parent 2e385d3 commit 85b507d

File tree

16 files changed

+469
-102
lines changed

16 files changed

+469
-102
lines changed

docs/changelog/126493.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 126493
2+
summary: Bedrock Cohere Task Settings Support
3+
area: Machine Learning
4+
type: enhancement
5+
issues:
6+
- 126156

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ static TransportVersion def(int id) {
201201
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
202202
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
203203
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
204+
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
204205

205206
/*
206207
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
4242
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
4343
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
44+
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
4445
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings;
4546
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings;
4647
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings;
@@ -173,8 +174,13 @@ private static void addAmazonBedrockNamedWriteables(List<NamedWriteableRegistry.
173174
AmazonBedrockEmbeddingsServiceSettings::new
174175
)
175176
);
176-
177-
// no task settings for Amazon Bedrock Embeddings
177+
namedWriteables.add(
178+
new NamedWriteableRegistry.Entry(
179+
TaskSettings.class,
180+
AmazonBedrockEmbeddingsTaskSettings.NAME,
181+
AmazonBedrockEmbeddingsTaskSettings::new
182+
)
183+
);
178184

179185
namedWriteables.add(
180186
new NamedWriteableRegistry.Entry(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,31 @@
1111
import org.elasticsearch.inference.InputType;
1212
import org.elasticsearch.xcontent.ToXContentObject;
1313
import org.elasticsearch.xcontent.XContentBuilder;
14+
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
1415

1516
import java.io.IOException;
1617
import java.util.List;
1718
import java.util.Objects;
1819

1920
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
2021

21-
public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nullable InputType inputType) implements ToXContentObject {
22+
public record AmazonBedrockCohereEmbeddingsRequestEntity(
23+
List<String> input,
24+
@Nullable InputType inputType,
25+
AmazonBedrockEmbeddingsTaskSettings taskSettings
26+
) implements ToXContentObject {
2227

2328
private static final String TEXTS_FIELD = "texts";
2429
private static final String INPUT_TYPE_FIELD = "input_type";
2530
private static final String SEARCH_DOCUMENT = "search_document";
2631
private static final String SEARCH_QUERY = "search_query";
2732
private static final String CLUSTERING = "clustering";
2833
private static final String CLASSIFICATION = "classification";
34+
private static final String TRUNCATE = "truncate";
2935

3036
public AmazonBedrockCohereEmbeddingsRequestEntity {
3137
Objects.requireNonNull(input);
38+
Objects.requireNonNull(taskSettings);
3239
}
3340

3441
@Override
@@ -43,6 +50,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4350
builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT);
4451
}
4552

53+
if (taskSettings.cohereTruncation() != null) {
54+
builder.field(TRUNCATE, taskSettings.cohereTruncation().name());
55+
}
56+
4657
builder.endObject();
4758
return builder;
4859
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public static ToXContent createEntity(
3939
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
4040
}
4141
case COHERE -> {
42-
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType);
42+
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings());
4343
}
4444
default -> {
4545
return null;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ protected void executeRequest(AmazonBedrockBaseClient client) {
7676

7777
@Override
7878
public Request truncate() {
79+
if (provider == AmazonBedrockProvider.COHERE) {
80+
return this; // Cohere has its own truncation logic
81+
}
7982
var truncatedInput = truncator.truncate(truncationResult.input());
8083
return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout);
8184
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ public class AmazonBedrockConstants {
1919
public static final String TOP_K_FIELD = "top_k";
2020
public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens";
2121

22+
public static final String TRUNCATE_FIELD = "truncate";
23+
2224
public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0;
2325
public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0;
2426

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ private static AmazonBedrockModel createModel(
303303
context
304304
);
305305
checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider());
306+
checkTaskSettingsForTextEmbeddingModel(model);
306307
return model;
307308
}
308309
case COMPLETION -> {
@@ -368,6 +369,17 @@ private static void checkProviderForTask(TaskType taskType, AmazonBedrockProvide
368369
}
369370
}
370371

372+
private static void checkTaskSettingsForTextEmbeddingModel(AmazonBedrockEmbeddingsModel model) {
373+
if (model.provider() != AmazonBedrockProvider.COHERE && model.getTaskSettings().cohereTruncation() != null) {
374+
throw new ElasticsearchStatusException(
375+
"The [{}] task type for provider [{}] does not allow [truncate] field",
376+
RestStatus.BAD_REQUEST,
377+
TaskType.TEXT_EMBEDDING,
378+
model.provider()
379+
);
380+
}
381+
}
382+
371383
private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) {
372384
var taskSettings = model.getTaskSettings();
373385
if (taskSettings.topK() != null) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,11 @@
77

88
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
99

10-
import org.elasticsearch.common.ValidationException;
1110
import org.elasticsearch.inference.ChunkingSettings;
12-
import org.elasticsearch.inference.EmptyTaskSettings;
1311
import org.elasticsearch.inference.Model;
1412
import org.elasticsearch.inference.ModelConfigurations;
1513
import org.elasticsearch.inference.ModelSecrets;
1614
import org.elasticsearch.inference.ServiceSettings;
17-
import org.elasticsearch.inference.TaskSettings;
1815
import org.elasticsearch.inference.TaskType;
1916
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
2017
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
@@ -28,10 +25,8 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
2825

2926
public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
3027
if (taskSettings != null && taskSettings.isEmpty() == false) {
31-
// no task settings allowed
32-
var validationException = new ValidationException();
33-
validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings");
34-
throw validationException;
28+
var updatedTaskSettings = embeddingsModel.getTaskSettings().updatedTaskSettings(taskSettings);
29+
return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedTaskSettings);
3530
}
3631

3732
return embeddingsModel;
@@ -52,7 +47,7 @@ public AmazonBedrockEmbeddingsModel(
5247
taskType,
5348
service,
5449
AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context),
55-
new EmptyTaskSettings(),
50+
AmazonBedrockEmbeddingsTaskSettings.fromMap(taskSettings),
5651
chunkingSettings,
5752
AwsSecretSettings.fromMap(secretSettings)
5853
);
@@ -63,12 +58,12 @@ public AmazonBedrockEmbeddingsModel(
6358
TaskType taskType,
6459
String service,
6560
AmazonBedrockEmbeddingsServiceSettings serviceSettings,
66-
TaskSettings taskSettings,
61+
AmazonBedrockEmbeddingsTaskSettings taskSettings,
6762
ChunkingSettings chunkingSettings,
6863
AwsSecretSettings secrets
6964
) {
7065
super(
71-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings),
66+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
7267
new ModelSecrets(secrets)
7368
);
7469
}
@@ -77,6 +72,10 @@ public AmazonBedrockEmbeddingsModel(Model model, ServiceSettings serviceSettings
7772
super(model, serviceSettings);
7873
}
7974

75+
public AmazonBedrockEmbeddingsModel(Model model, AmazonBedrockEmbeddingsTaskSettings taskSettings) {
76+
super(model, taskSettings);
77+
}
78+
8079
@Override
8180
public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, Object> taskSettings) {
8281
return creator.create(this, taskSettings);
@@ -86,4 +85,9 @@ public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, O
8685
public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() {
8786
return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings();
8887
}
88+
89+
@Override
90+
public AmazonBedrockEmbeddingsTaskSettings getTaskSettings() {
91+
return (AmazonBedrockEmbeddingsTaskSettings) super.getTaskSettings();
92+
}
8993
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.common.ValidationException;
13+
import org.elasticsearch.common.io.stream.StreamInput;
14+
import org.elasticsearch.common.io.stream.StreamOutput;
15+
import org.elasticsearch.core.Nullable;
16+
import org.elasticsearch.inference.ModelConfigurations;
17+
import org.elasticsearch.inference.TaskSettings;
18+
import org.elasticsearch.xcontent.XContentBuilder;
19+
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
20+
21+
import java.io.IOException;
22+
import java.util.HashMap;
23+
import java.util.Map;
24+
25+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
26+
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
27+
28+
public record AmazonBedrockEmbeddingsTaskSettings(@Nullable CohereTruncation cohereTruncation) implements TaskSettings {
29+
public static final AmazonBedrockEmbeddingsTaskSettings EMPTY = new AmazonBedrockEmbeddingsTaskSettings((CohereTruncation) null);
30+
public static final String NAME = "amazon_bedrock_embeddings_task_settings";
31+
32+
public static AmazonBedrockEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
33+
if (map == null || map.isEmpty()) {
34+
return EMPTY;
35+
}
36+
37+
ValidationException validationException = new ValidationException();
38+
39+
var cohereTruncation = extractOptionalEnum(
40+
map,
41+
TRUNCATE_FIELD,
42+
ModelConfigurations.TASK_SETTINGS,
43+
CohereTruncation::fromString,
44+
CohereTruncation.ALL,
45+
validationException
46+
);
47+
48+
if (validationException.validationErrors().isEmpty() == false) {
49+
throw validationException;
50+
}
51+
52+
return new AmazonBedrockEmbeddingsTaskSettings(cohereTruncation);
53+
}
54+
55+
public AmazonBedrockEmbeddingsTaskSettings(StreamInput in) throws IOException {
56+
this(in.readOptionalEnum(CohereTruncation.class));
57+
}
58+
59+
@Override
60+
public boolean isEmpty() {
61+
return cohereTruncation() == null;
62+
}
63+
64+
@Override
65+
public AmazonBedrockEmbeddingsTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
66+
var newTaskSettings = fromMap(new HashMap<>(newSettings));
67+
68+
return new AmazonBedrockEmbeddingsTaskSettings(firstNonNullOrNull(newTaskSettings.cohereTruncation(), cohereTruncation()));
69+
}
70+
71+
private static <T> T firstNonNullOrNull(T first, T second) {
72+
return first != null ? first : second;
73+
}
74+
75+
@Override
76+
public String getWriteableName() {
77+
return NAME;
78+
}
79+
80+
@Override
81+
public TransportVersion getMinimalSupportedVersion() {
82+
return TransportVersions.AMAZON_BEDROCK_TASK_SETTINGS_8_19;
83+
}
84+
85+
@Override
86+
public void writeTo(StreamOutput out) throws IOException {
87+
out.writeOptionalEnum(cohereTruncation());
88+
}
89+
90+
@Override
91+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
92+
builder.startObject();
93+
if (cohereTruncation != null) {
94+
builder.field(TRUNCATE_FIELD, cohereTruncation);
95+
}
96+
return builder.endObject();
97+
}
98+
}

0 commit comments

Comments
 (0)