Skip to content

Commit a06fd59

Browse files
Add Amazon Bedrock Unified Chat Completions support
1 parent 9848a60 commit a06fd59

File tree

12 files changed

+430
-8
lines changed

12 files changed

+430
-8
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
157157
providersFor(TaskType.CHAT_COMPLETION),
158158
containsInAnyOrder(
159159
List.of(
160+
"amazonbedrock",
160161
"deepseek",
161162
"elastic",
162163
"openai",

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockChatCompletionRequestManager.java

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,18 @@
1919
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
2020
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
2121
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
22+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
2223
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel;
2324
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionEntityFactory;
2425
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionRequest;
26+
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionEntityFactory;
27+
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionRequest;
2528
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.completion.AmazonBedrockChatCompletionResponseHandler;
2629

2730
import java.util.function.Supplier;
2831

32+
import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException;
33+
2934
public class AmazonBedrockChatCompletionRequestManager extends AmazonBedrockRequestManager {
3035
private static final Logger logger = LogManager.getLogger(AmazonBedrockChatCompletionRequestManager.class);
3136
private final AmazonBedrockChatCompletionModel model;
@@ -46,9 +51,45 @@ public void execute(
4651
Supplier<Boolean> hasRequestCompletedFunction,
4752
ActionListener<InferenceServiceResults> listener
4853
) {
49-
var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class);
50-
var inputs = chatCompletionInput.getInputs();
51-
var stream = chatCompletionInput.stream();
54+
switch (inferenceInputs) {
55+
case UnifiedChatInput uci -> execute(uci, requestSender, hasRequestCompletedFunction, listener);
56+
case ChatCompletionInput cci -> execute(cci, requestSender, hasRequestCompletedFunction, listener);
57+
default -> throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class);
58+
}
59+
}
60+
61+
private void execute(
62+
UnifiedChatInput inferenceInputs,
63+
RequestSender requestSender,
64+
Supplier<Boolean> hasRequestCompletedFunction,
65+
ActionListener<InferenceServiceResults> listener
66+
) {
67+
var inputs = inferenceInputs.getRequest();
68+
var stream = inferenceInputs.stream();
69+
var requestEntity = AmazonBedrockUnifiedChatCompletionEntityFactory.createEntity(model, inputs);
70+
var request = new AmazonBedrockUnifiedChatCompletionRequest(model, requestEntity, timeout, stream);
71+
var responseHandler = new AmazonBedrockChatCompletionResponseHandler();
72+
73+
try {
74+
requestSender.send(logger, request, hasRequestCompletedFunction, responseHandler, listener);
75+
} catch (Exception e) {
76+
var errorMessage = Strings.format(
77+
"Failed to send [completion] request from inference entity id [%s]",
78+
request.getInferenceEntityId()
79+
);
80+
logger.warn(errorMessage, e);
81+
listener.onFailure(new ElasticsearchException(errorMessage, e));
82+
}
83+
}
84+
85+
private void execute(
86+
ChatCompletionInput inferenceInputs,
87+
RequestSender requestSender,
88+
Supplier<Boolean> hasRequestCompletedFunction,
89+
ActionListener<InferenceServiceResults> listener
90+
) {
91+
var inputs = inferenceInputs.getInputs();
92+
var stream = inferenceInputs.stream();
5293
var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, inputs);
5394
var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout, stream);
5495
var responseHandler = new AmazonBedrockChatCompletionResponseHandler();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public final class AmazonBedrockProviderCapabilities {
5959

6060
public static boolean providerAllowsTaskType(AmazonBedrockProvider provider, TaskType taskType) {
6161
switch (taskType) {
62-
case COMPLETION -> {
62+
case COMPLETION, CHAT_COMPLETION -> {
6363
return chatCompletionProviders.contains(provider);
6464
}
6565
case TEXT_EMBEDDING -> {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3535
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
3636
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
37+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
3738
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
3839
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3940
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
@@ -64,7 +65,6 @@
6465
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
6566
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
6667
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
67-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
6868
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD;
6969
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD;
7070
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD;
@@ -77,10 +77,12 @@
7777
public class AmazonBedrockService extends SenderService {
7878
public static final String NAME = "amazonbedrock";
7979
private static final String SERVICE_NAME = "Amazon Bedrock";
80+
public static final String COMPLETION_ERROR_PREFIX = "Amazon Bedrock chat completion";
8081

8182
private final Sender amazonBedrockSender;
8283

83-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
84+
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
85+
TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
8486

8587
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
8688
InputType.INGEST,
@@ -118,7 +120,7 @@ protected void doUnifiedCompletionInfer(
118120
TimeValue timeout,
119121
ActionListener<InferenceServiceResults> listener
120122
) {
121-
throwUnsupportedUnifiedCompletionOperation(NAME);
123+
infer(model, inputs, null, timeout, listener);
122124
}
123125

124126
@Override
@@ -128,6 +130,16 @@ protected void doInfer(
128130
Map<String, Object> taskSettings,
129131
TimeValue timeout,
130132
ActionListener<InferenceServiceResults> listener
133+
) {
134+
infer(model, inputs, taskSettings, timeout, listener);
135+
}
136+
137+
private void infer(
138+
Model model,
139+
InferenceInputs inputs,
140+
Map<String, Object> taskSettings,
141+
TimeValue timeout,
142+
ActionListener<InferenceServiceResults> listener
131143
) {
132144
var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout);
133145
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
@@ -303,6 +315,19 @@ private static AmazonBedrockModel createModel(
303315
checkTaskSettingsForTextEmbeddingModel(model);
304316
return model;
305317
}
318+
case CHAT_COMPLETION -> {
319+
var model = new AmazonBedrockChatCompletionModel(
320+
inferenceEntityId,
321+
taskType,
322+
NAME,
323+
serviceSettings,
324+
taskSettings,
325+
secretSettings,
326+
context
327+
);
328+
checkProviderForTask(TaskType.CHAT_COMPLETION, model.provider());
329+
return model;
330+
}
306331
case COMPLETION -> {
307332
var model = new AmazonBedrockChatCompletionModel(
308333
inferenceEntityId,
@@ -328,7 +353,7 @@ public TransportVersion getMinimalSupportedVersion() {
328353

329354
@Override
330355
public Set<TaskType> supportedStreamingTasks() {
331-
return COMPLETION_ONLY;
356+
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
332357
}
333358

334359
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecuteOnlyRequestSender.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
1919
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest;
2020
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionRequest;
21+
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionRequest;
2122
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockEmbeddingsRequest;
2223
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponseHandler;
2324

@@ -83,6 +84,16 @@ protected AmazonBedrockExecutor createExecutor(
8384
clientCache
8485
);
8586
}
87+
case CHAT_COMPLETION -> {
88+
return new AmazonBedrockUnifiedChatCompletionExecutor(
89+
(AmazonBedrockUnifiedChatCompletionRequest) awsRequest,
90+
awsResponse,
91+
logger,
92+
hasRequestTimedOutFunction,
93+
listener,
94+
clientCache
95+
);
96+
}
8697
case TEXT_EMBEDDING -> {
8798
return new AmazonBedrockEmbeddingsExecutor(
8899
(AmazonBedrockEmbeddingsRequest) awsRequest,
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.amazonbedrock.client;
9+
10+
import org.apache.logging.log4j.Logger;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.inference.InferenceServiceResults;
13+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
14+
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionRequest;
15+
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponseHandler;
16+
17+
import java.util.function.Supplier;
18+
19+
public class AmazonBedrockUnifiedChatCompletionExecutor extends AmazonBedrockExecutor {
20+
private final AmazonBedrockUnifiedChatCompletionRequest chatCompletionRequest;
21+
22+
protected AmazonBedrockUnifiedChatCompletionExecutor(
23+
AmazonBedrockUnifiedChatCompletionRequest request,
24+
AmazonBedrockResponseHandler responseHandler,
25+
Logger logger,
26+
Supplier<Boolean> hasRequestCompletedFunction,
27+
ActionListener<InferenceServiceResults> inferenceResultsListener,
28+
AmazonBedrockClientCache clientCache
29+
) {
30+
super(request, responseHandler, logger, hasRequestCompletedFunction, inferenceResultsListener, clientCache);
31+
this.chatCompletionRequest = request;
32+
}
33+
34+
@Override
35+
protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) {
36+
if (chatCompletionRequest.isStreaming()) {
37+
var publisher = chatCompletionRequest.executeStreamChatCompletionRequest(awsBedrockClient);
38+
inferenceResultsListener.onResponse(new StreamingChatCompletionResults(publisher));
39+
}
40+
}
41+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
import org.elasticsearch.inference.ModelSecrets;
1313
import org.elasticsearch.inference.TaskSettings;
1414
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.inference.UnifiedCompletionRequest;
1516
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
1617
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1718
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
1819
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel;
1920
import org.elasticsearch.xpack.inference.services.amazonbedrock.action.AmazonBedrockActionVisitor;
2021

2122
import java.util.Map;
23+
import java.util.Objects;
2224

2325
public class AmazonBedrockChatCompletionModel extends AmazonBedrockModel {
2426

@@ -32,6 +34,28 @@ public static AmazonBedrockChatCompletionModel of(AmazonBedrockChatCompletionMod
3234
return new AmazonBedrockChatCompletionModel(completionModel, taskSettingsToUse);
3335
}
3436

37+
38+
public static AmazonBedrockChatCompletionModel of(AmazonBedrockChatCompletionModel model, UnifiedCompletionRequest request) {
39+
if (request.model() == null) {
40+
return model;
41+
}
42+
var originalModelServiceSettings = model.getServiceSettings();
43+
var overriddenServiceSettings = new AmazonBedrockChatCompletionServiceSettings(
44+
originalModelServiceSettings.region(),
45+
Objects.requireNonNull(request.model(), originalModelServiceSettings.modelId()),
46+
originalModelServiceSettings.provider(),
47+
originalModelServiceSettings.rateLimitSettings()
48+
);
49+
return new AmazonBedrockChatCompletionModel(
50+
model.getInferenceEntityId(),
51+
model.getTaskType(),
52+
model.getConfigurations().getService(),
53+
overriddenServiceSettings,
54+
model.getTaskSettings(),
55+
model.getSecretSettings()
56+
);
57+
}
58+
3559
public AmazonBedrockChatCompletionModel(
3660
String inferenceEntityId,
3761
TaskType taskType,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockConverseUtils.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion;
99

10+
import org.elasticsearch.inference.UnifiedCompletionRequest;
11+
1012
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
1113
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
1214
import software.amazon.awssdk.services.bedrockruntime.model.Message;
@@ -28,6 +30,15 @@ public static List<Message> getConverseMessageList(List<String> texts) {
2830
.toList();
2931
}
3032

33+
public static List<Message> getUnifiedConverseMessageList(List<UnifiedCompletionRequest.Message> messages) {
34+
return messages.stream()
35+
.map(message -> Message.builder().role(message.role())
36+
.content(ContentBlock.builder()
37+
.text(message.content().toString())
38+
.build()).build())
39+
.toList();
40+
}
41+
3142
public static Optional<InferenceConfiguration> inferenceConfig(AmazonBedrockConverseRequestEntity request) {
3243
if (request.temperature() != null || request.topP() != null || request.maxTokenCount() != null) {
3344
var builder = InferenceConfiguration.builder();
@@ -47,6 +58,25 @@ public static Optional<InferenceConfiguration> inferenceConfig(AmazonBedrockConv
4758
return Optional.empty();
4859
}
4960

61+
public static Optional<InferenceConfiguration> inferenceConfig(AmazonBedrockUnifiedConverseRequestEntity request) {
62+
if (request.temperature() != null || request.topP() != null || request.maxCompletionTokens() != null) {
63+
var builder = InferenceConfiguration.builder();
64+
if (request.temperature() != null) {
65+
builder.temperature(request.temperature().floatValue());
66+
}
67+
68+
if (request.topP() != null) {
69+
builder.topP(request.topP().floatValue());
70+
}
71+
72+
if (request.maxCompletionTokens() != null) {
73+
builder.maxTokens(Math.toIntExact(request.maxCompletionTokens()));
74+
}
75+
return Optional.of(builder.build());
76+
}
77+
return Optional.empty();
78+
}
79+
5080
@Nullable
5181
public static List<String> additionalTopK(@Nullable Double topK) {
5282
if (topK == null) {
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion;
9+
10+
import org.elasticsearch.inference.UnifiedCompletionRequest;
11+
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel;
12+
13+
import java.util.Objects;
14+
15+
public class AmazonBedrockUnifiedChatCompletionEntityFactory {
16+
public static AmazonBedrockUnifiedConverseRequestEntity createEntity(
17+
AmazonBedrockChatCompletionModel model, UnifiedCompletionRequest request) {
18+
Objects.requireNonNull(model);
19+
Objects.requireNonNull(request);
20+
var serviceSettings = model.getServiceSettings();
21+
22+
var messages = request.messages().stream()
23+
.map(message -> new UnifiedCompletionRequest.Message(
24+
message.content(),
25+
toBedrockRole(message.role()),
26+
message.toolCallId(),
27+
message.toolCalls()
28+
))
29+
.toList();
30+
31+
switch (serviceSettings.provider()) {
32+
case ANTHROPIC, AI21LABS, AMAZONTITAN, COHERE, META, MISTRAL -> {
33+
return new AmazonBedrockUnifiedConverseRequestEntity(
34+
messages,
35+
request.model(),
36+
request.maxCompletionTokens(),
37+
request.stop(),
38+
request.temperature(),
39+
request.toolChoice(),
40+
request.tools(),
41+
request.topP()
42+
);
43+
}
44+
default -> {
45+
return null;
46+
}
47+
}
48+
}
49+
50+
private static String toBedrockRole(String role) {
51+
return role == null ? "user" : role;
52+
}
53+
}

0 commit comments

Comments
 (0)