Skip to content

Commit b2e04a3

Browse files
Add Amazon Bedrock Unified Chat Completions support
1 parent cebe88f commit b2e04a3

File tree

7 files changed

+550
-8
lines changed

7 files changed

+550
-8
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockClient.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.services.amazonbedrock.client;
99

10+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
11+
1012
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
1113
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
1214
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
@@ -26,6 +28,9 @@ public interface AmazonBedrockClient {
2628
Flow.Publisher<? extends InferenceServiceResults.Result> converseStream(ConverseStreamRequest converseStreamRequest)
2729
throws ElasticsearchException;
2830

31+
Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> converseUnifiedStream(ConverseStreamRequest request)
32+
throws ElasticsearchException;
33+
2934
void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener<InvokeModelResponse> responseListener)
3035
throws ElasticsearchException;
3136

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockInferenceClient.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.services.amazonbedrock.client;
99

10+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
11+
1012
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
1113
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
1214
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
@@ -101,6 +103,20 @@ public Flow.Publisher<? extends InferenceServiceResults.Result> converseStream(C
101103
return awsResponseProcessor;
102104
}
103105

106+
@Override
107+
public Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> converseUnifiedStream(ConverseStreamRequest request)
108+
throws ElasticsearchException {
109+
var awsResponseProcessor = new AmazonBedrockUnifiedStreamingChatProcessor(threadPool);
110+
internalClient.converseStream(
111+
request,
112+
ConverseStreamResponseHandler.builder().subscriber(() -> FlowAdapters.toSubscriber(awsResponseProcessor)).build()
113+
).exceptionally(e -> {
114+
awsResponseProcessor.onError(e);
115+
return null; // Void
116+
});
117+
return awsResponseProcessor;
118+
}
119+
104120
private void onFailure(ActionListener<?> listener, Throwable t, String method) {
105121
ExceptionsHelper.maybeDieOnAnotherThread(t);
106122
var unwrappedException = t;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedChatCompletionExecutor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import org.apache.logging.log4j.Logger;
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.inference.InferenceServiceResults;
13-
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
13+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
1414
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionRequest;
1515
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponseHandler;
1616

@@ -35,7 +35,7 @@ protected AmazonBedrockUnifiedChatCompletionExecutor(
3535
protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) {
3636
if (chatCompletionRequest.isStreaming()) {
3737
var publisher = chatCompletionRequest.executeStreamChatCompletionRequest(awsBedrockClient);
38-
inferenceResultsListener.onResponse(new StreamingChatCompletionResults(publisher));
38+
inferenceResultsListener.onResponse(new StreamingUnifiedChatCompletionResults(publisher));
3939
}
4040
}
4141
}

0 commit comments

Comments
 (0)