Skip to content

Commit 85b507d

Browse files
authored
[ML] Bedrock Cohere Task Settings Support (elastic#126493) (elastic#126559)
Add support for Cohere Task Settings and Truncate, through the Amazon Bedrock provider integration. Task Settings can now be passed bother during Inference endpoint creation and Inference POST requests. Relate elastic#126156
1 parent 2e385d3 commit 85b507d

File tree

16 files changed

+469
-102
lines changed

16 files changed

+469
-102
lines changed

docs/changelog/126493.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 126493
2+
summary: Bedrock Cohere Task Settings Support
3+
area: Machine Learning
4+
type: enhancement
5+
issues:
6+
- 126156

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ static TransportVersion def(int id) {
201201
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
202202
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
203203
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
204+
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
204205

205206
/*
206207
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
4242
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
4343
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
44+
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
4445
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings;
4546
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings;
4647
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings;
@@ -173,8 +174,13 @@ private static void addAmazonBedrockNamedWriteables(List<NamedWriteableRegistry.
173174
AmazonBedrockEmbeddingsServiceSettings::new
174175
)
175176
);
176-
177-
// no task settings for Amazon Bedrock Embeddings
177+
namedWriteables.add(
178+
new NamedWriteableRegistry.Entry(
179+
TaskSettings.class,
180+
AmazonBedrockEmbeddingsTaskSettings.NAME,
181+
AmazonBedrockEmbeddingsTaskSettings::new
182+
)
183+
);
178184

179185
namedWriteables.add(
180186
new NamedWriteableRegistry.Entry(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,31 @@
1111
import org.elasticsearch.inference.InputType;
1212
import org.elasticsearch.xcontent.ToXContentObject;
1313
import org.elasticsearch.xcontent.XContentBuilder;
14+
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
1415

1516
import java.io.IOException;
1617
import java.util.List;
1718
import java.util.Objects;
1819

1920
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
2021

21-
public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nullable InputType inputType) implements ToXContentObject {
22+
public record AmazonBedrockCohereEmbeddingsRequestEntity(
23+
List<String> input,
24+
@Nullable InputType inputType,
25+
AmazonBedrockEmbeddingsTaskSettings taskSettings
26+
) implements ToXContentObject {
2227

2328
private static final String TEXTS_FIELD = "texts";
2429
private static final String INPUT_TYPE_FIELD = "input_type";
2530
private static final String SEARCH_DOCUMENT = "search_document";
2631
private static final String SEARCH_QUERY = "search_query";
2732
private static final String CLUSTERING = "clustering";
2833
private static final String CLASSIFICATION = "classification";
34+
private static final String TRUNCATE = "truncate";
2935

3036
public AmazonBedrockCohereEmbeddingsRequestEntity {
3137
Objects.requireNonNull(input);
38+
Objects.requireNonNull(taskSettings);
3239
}
3340

3441
@Override
@@ -43,6 +50,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4350
builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT);
4451
}
4552

53+
if (taskSettings.cohereTruncation() != null) {
54+
builder.field(TRUNCATE, taskSettings.cohereTruncation().name());
55+
}
56+
4657
builder.endObject();
4758
return builder;
4859
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public static ToXContent createEntity(
3939
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
4040
}
4141
case COHERE -> {
42-
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType);
42+
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings());
4343
}
4444
default -> {
4545
return null;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ protected void executeRequest(AmazonBedrockBaseClient client) {
7676

7777
@Override
7878
public Request truncate() {
79+
if (provider == AmazonBedrockProvider.COHERE) {
80+
return this; // Cohere has its own truncation logic
81+
}
7982
var truncatedInput = truncator.truncate(truncationResult.input());
8083
return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout);
8184
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ public class AmazonBedrockConstants {
1919
public static final String TOP_K_FIELD = "top_k";
2020
public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens";
2121

22+
public static final String TRUNCATE_FIELD = "truncate";
23+
2224
public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0;
2325
public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0;
2426

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ private static AmazonBedrockModel createModel(
303303
context
304304
);
305305
checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider());
306+
checkTaskSettingsForTextEmbeddingModel(model);
306307
return model;
307308
}
308309
case COMPLETION -> {
@@ -368,6 +369,17 @@ private static void checkProviderForTask(TaskType taskType, AmazonBedrockProvide
368369
}
369370
}
370371

372+
private static void checkTaskSettingsForTextEmbeddingModel(AmazonBedrockEmbeddingsModel model) {
373+
if (model.provider() != AmazonBedrockProvider.COHERE && model.getTaskSettings().cohereTruncation() != null) {
374+
throw new ElasticsearchStatusException(
375+
"The [{}] task type for provider [{}] does not allow [truncate] field",
376+
RestStatus.BAD_REQUEST,
377+
TaskType.TEXT_EMBEDDING,
378+
model.provider()
379+
);
380+
}
381+
}
382+
371383
private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) {
372384
var taskSettings = model.getTaskSettings();
373385
if (taskSettings.topK() != null) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,11 @@
77

88
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
99

10-
import org.elasticsearch.common.ValidationException;
1110
import org.elasticsearch.inference.ChunkingSettings;
12-
import org.elasticsearch.inference.EmptyTaskSettings;
1311
import org.elasticsearch.inference.Model;
1412
import org.elasticsearch.inference.ModelConfigurations;
1513
import org.elasticsearch.inference.ModelSecrets;
1614
import org.elasticsearch.inference.ServiceSettings;
17-
import org.elasticsearch.inference.TaskSettings;
1815
import org.elasticsearch.inference.TaskType;
1916
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
2017
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
@@ -28,10 +25,8 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
2825

2926
public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
3027
if (taskSettings != null && taskSettings.isEmpty() == false) {
31-
// no task settings allowed
32-
var validationException = new ValidationException();
33-
validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings");
34-
throw validationException;
28+
var updatedTaskSettings = embeddingsModel.getTaskSettings().updatedTaskSettings(taskSettings);
29+
return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedTaskSettings);
3530
}
3631

3732
return embeddingsModel;
@@ -52,7 +47,7 @@ public AmazonBedrockEmbeddingsModel(
5247
taskType,
5348
service,
5449
AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context),
55-
new EmptyTaskSettings(),
50+
AmazonBedrockEmbeddingsTaskSettings.fromMap(taskSettings),
5651
chunkingSettings,
5752
AwsSecretSettings.fromMap(secretSettings)
5853
);
@@ -63,12 +58,12 @@ public AmazonBedrockEmbeddingsModel(
6358
TaskType taskType,
6459
String service,
6560
AmazonBedrockEmbeddingsServiceSettings serviceSettings,
66-
TaskSettings taskSettings,
61+
AmazonBedrockEmbeddingsTaskSettings taskSettings,
6762
ChunkingSettings chunkingSettings,
6863
AwsSecretSettings secrets
6964
) {
7065
super(
71-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings),
66+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
7267
new ModelSecrets(secrets)
7368
);
7469
}
@@ -77,6 +72,10 @@ public AmazonBedrockEmbeddingsModel(Model model, ServiceSettings serviceSettings
7772
super(model, serviceSettings);
7873
}
7974

75+
public AmazonBedrockEmbeddingsModel(Model model, AmazonBedrockEmbeddingsTaskSettings taskSettings) {
76+
super(model, taskSettings);
77+
}
78+
8079
@Override
8180
public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, Object> taskSettings) {
8281
return creator.create(this, taskSettings);
@@ -86,4 +85,9 @@ public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, O
8685
public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() {
8786
return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings();
8887
}
88+
89+
@Override
90+
public AmazonBedrockEmbeddingsTaskSettings getTaskSettings() {
91+
return (AmazonBedrockEmbeddingsTaskSettings) super.getTaskSettings();
92+
}
8993
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.common.ValidationException;
13+
import org.elasticsearch.common.io.stream.StreamInput;
14+
import org.elasticsearch.common.io.stream.StreamOutput;
15+
import org.elasticsearch.core.Nullable;
16+
import org.elasticsearch.inference.ModelConfigurations;
17+
import org.elasticsearch.inference.TaskSettings;
18+
import org.elasticsearch.xcontent.XContentBuilder;
19+
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
20+
21+
import java.io.IOException;
22+
import java.util.HashMap;
23+
import java.util.Map;
24+
25+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
26+
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
27+
28+
public record AmazonBedrockEmbeddingsTaskSettings(@Nullable CohereTruncation cohereTruncation) implements TaskSettings {
29+
public static final AmazonBedrockEmbeddingsTaskSettings EMPTY = new AmazonBedrockEmbeddingsTaskSettings((CohereTruncation) null);
30+
public static final String NAME = "amazon_bedrock_embeddings_task_settings";
31+
32+
public static AmazonBedrockEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
33+
if (map == null || map.isEmpty()) {
34+
return EMPTY;
35+
}
36+
37+
ValidationException validationException = new ValidationException();
38+
39+
var cohereTruncation = extractOptionalEnum(
40+
map,
41+
TRUNCATE_FIELD,
42+
ModelConfigurations.TASK_SETTINGS,
43+
CohereTruncation::fromString,
44+
CohereTruncation.ALL,
45+
validationException
46+
);
47+
48+
if (validationException.validationErrors().isEmpty() == false) {
49+
throw validationException;
50+
}
51+
52+
return new AmazonBedrockEmbeddingsTaskSettings(cohereTruncation);
53+
}
54+
55+
public AmazonBedrockEmbeddingsTaskSettings(StreamInput in) throws IOException {
56+
this(in.readOptionalEnum(CohereTruncation.class));
57+
}
58+
59+
@Override
60+
public boolean isEmpty() {
61+
return cohereTruncation() == null;
62+
}
63+
64+
@Override
65+
public AmazonBedrockEmbeddingsTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
66+
var newTaskSettings = fromMap(new HashMap<>(newSettings));
67+
68+
return new AmazonBedrockEmbeddingsTaskSettings(firstNonNullOrNull(newTaskSettings.cohereTruncation(), cohereTruncation()));
69+
}
70+
71+
private static <T> T firstNonNullOrNull(T first, T second) {
72+
return first != null ? first : second;
73+
}
74+
75+
@Override
76+
public String getWriteableName() {
77+
return NAME;
78+
}
79+
80+
@Override
81+
public TransportVersion getMinimalSupportedVersion() {
82+
return TransportVersions.AMAZON_BEDROCK_TASK_SETTINGS_8_19;
83+
}
84+
85+
@Override
86+
public void writeTo(StreamOutput out) throws IOException {
87+
out.writeOptionalEnum(cohereTruncation());
88+
}
89+
90+
@Override
91+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
92+
builder.startObject();
93+
if (cohereTruncation != null) {
94+
builder.field(TRUNCATE_FIELD, cohereTruncation);
95+
}
96+
return builder.endObject();
97+
}
98+
}

0 commit comments

Comments
 (0)