diff --git a/server/src/main/resources/transport/upper_bounds/8.18.csv b/server/src/main/resources/transport/upper_bounds/8.18.csv index 4eb5140004ea6..266bfbbd3bf78 100644 --- a/server/src/main/resources/transport/upper_bounds/8.18.csv +++ b/server/src/main/resources/transport/upper_bounds/8.18.csv @@ -1 +1 @@ -initial_elasticsearch_8_18_6,8840008 +transform_check_for_dangling_tasks,8840011 diff --git a/server/src/main/resources/transport/upper_bounds/8.19.csv b/server/src/main/resources/transport/upper_bounds/8.19.csv index 476468b203875..3600b3f8c633a 100644 --- a/server/src/main/resources/transport/upper_bounds/8.19.csv +++ b/server/src/main/resources/transport/upper_bounds/8.19.csv @@ -1 +1 @@ -initial_elasticsearch_8_19_3,8841067 +transform_check_for_dangling_tasks,8841070 diff --git a/server/src/main/resources/transport/upper_bounds/9.0.csv b/server/src/main/resources/transport/upper_bounds/9.0.csv index f8f50cc6d7839..c11e6837bb813 100644 --- a/server/src/main/resources/transport/upper_bounds/9.0.csv +++ b/server/src/main/resources/transport/upper_bounds/9.0.csv @@ -1 +1 @@ -initial_elasticsearch_9_0_6,9000015 +transform_check_for_dangling_tasks,9000018 diff --git a/server/src/main/resources/transport/upper_bounds/9.1.csv b/server/src/main/resources/transport/upper_bounds/9.1.csv index 5a65f2e578156..80b97d85f7511 100644 --- a/server/src/main/resources/transport/upper_bounds/9.1.csv +++ b/server/src/main/resources/transport/upper_bounds/9.1.csv @@ -1 +1 @@ -initial_elasticsearch_9_1_4,9112007 +transform_check_for_dangling_tasks,9112009 diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index e24f914a1d1ca..2147eab66c207 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -ml_inference_endpoint_cache,9157000 +initial_9.2.0,9185000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv new file mode 100644 index 0000000000000..2147eab66c207 --- /dev/null +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -0,0 +1 @@ +initial_9.2.0,9185000 diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 21e933292a0ed..674bbaeb218e0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -163,6 +163,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { containsInAnyOrder( List.of( "ai21", + "amazonbedrock", "llama", "deepseek", "elastic", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockChatCompletionRequestManager.java index 4cdb107577b56..8e528d7c5198c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockChatCompletionRequestManager.java @@ -19,13 +19,18 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionEntityFactory; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionEntityFactory; +import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.amazonbedrock.response.completion.AmazonBedrockChatCompletionResponseHandler; import java.util.function.Supplier; +import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException; + public class AmazonBedrockChatCompletionRequestManager extends AmazonBedrockRequestManager { private static final Logger logger = LogManager.getLogger(AmazonBedrockChatCompletionRequestManager.class); private final AmazonBedrockChatCompletionModel model; @@ -46,9 +51,45 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); - var inputs = chatCompletionInput.getInputs(); - var stream = chatCompletionInput.stream(); + switch (inferenceInputs) { + case UnifiedChatInput uci -> execute(uci, requestSender, hasRequestCompletedFunction, listener); + case ChatCompletionInput cci -> execute(cci, requestSender, hasRequestCompletedFunction, listener); + default -> throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class); + } + } + + private void execute( + UnifiedChatInput inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var inputs = inferenceInputs.getRequest(); + var stream = inferenceInputs.stream(); + var requestEntity = AmazonBedrockUnifiedChatCompletionEntityFactory.createEntity(model, inputs); + var request = new AmazonBedrockUnifiedChatCompletionRequest(model, requestEntity, timeout, stream); + var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); + + try { + requestSender.send(logger, request, hasRequestCompletedFunction, responseHandler, listener); + } catch (Exception e) { + var errorMessage = Strings.format( + "Failed to send [completion] request from inference entity id [%s]", + request.getInferenceEntityId() + ); + logger.warn(errorMessage, e); + listener.onFailure(new ElasticsearchException(errorMessage, e)); + } + } + + private void execute( + ChatCompletionInput inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var inputs = inferenceInputs.getInputs(); + var stream = inferenceInputs.stream(); var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, inputs); var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout, stream); var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java index 28b10ef294bda..e6242995a7b1d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java @@ -59,7 +59,7 @@ public final class AmazonBedrockProviderCapabilities { public static boolean providerAllowsTaskType(AmazonBedrockProvider provider, TaskType taskType) { switch (taskType) { - case COMPLETION -> { + case COMPLETION, CHAT_COMPLETION -> { return chatCompletionProviders.contains(provider); } case TEXT_EMBEDDING -> { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index c2b0ae8e69c37..67d3c8c5510a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -64,7 +64,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; @@ -77,10 +76,15 @@ public class AmazonBedrockService extends SenderService { public static final String NAME = "amazonbedrock"; private static final String SERVICE_NAME = "Amazon Bedrock"; + public static final String COMPLETION_ERROR_PREFIX = "Amazon Bedrock chat completion"; private final Sender amazonBedrockSender; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION); + private static final EnumSet supportedTaskTypes = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION + ); private static final EnumSet VALID_INPUT_TYPE_VALUES = EnumSet.of( InputType.INGEST, @@ -118,7 +122,7 @@ protected void doUnifiedCompletionInfer( TimeValue timeout, ActionListener listener ) { - throwUnsupportedUnifiedCompletionOperation(NAME); + infer(model, inputs, null, timeout, listener); } @Override @@ -128,6 +132,16 @@ protected void doInfer( Map taskSettings, TimeValue timeout, ActionListener listener + ) { + infer(model, inputs, taskSettings, timeout, listener); + } + + private void infer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener ) { var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout); if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) { @@ -303,6 +317,19 @@ private static AmazonBedrockModel createModel( checkTaskSettingsForTextEmbeddingModel(model); return model; } + case CHAT_COMPLETION -> { + var model = new AmazonBedrockChatCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + checkProviderForTask(TaskType.CHAT_COMPLETION, model.provider()); + return model; + } case COMPLETION -> { var model = new AmazonBedrockChatCompletionModel( inferenceEntityId, @@ -328,7 +355,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public Set supportedStreamingTasks() { - return COMPLETION_ONLY; + return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockClient.java index 799efbad517d5..9e4b4687ec0cf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockClient.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockClient.java @@ -16,6 +16,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import java.time.Instant; import java.util.concurrent.Flow; @@ -26,6 +27,9 @@ public interface AmazonBedrockClient { Flow.Publisher converseStream(ConverseStreamRequest converseStreamRequest) throws ElasticsearchException; + Flow.Publisher converseUnifiedStream(ConverseStreamRequest request) + throws ElasticsearchException; + void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener) throws ElasticsearchException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecuteOnlyRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecuteOnlyRequestSender.java index 50fd9db6d1c44..39aa190e8f65c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecuteOnlyRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecuteOnlyRequestSender.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockEmbeddingsRequest; import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponseHandler; @@ -83,6 +84,16 @@ protected AmazonBedrockExecutor createExecutor( clientCache ); } + case CHAT_COMPLETION -> { + return new AmazonBedrockUnifiedChatCompletionExecutor( + (AmazonBedrockUnifiedChatCompletionRequest) awsRequest, + awsResponse, + logger, + hasRequestTimedOutFunction, + listener, + clientCache + ); + } case TEXT_EMBEDDING -> { return new AmazonBedrockEmbeddingsExecutor( (AmazonBedrockEmbeddingsRequest) awsRequest, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockInferenceClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockInferenceClient.java index f19866febea2c..b3cbf4d5d35e8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockInferenceClient.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockInferenceClient.java @@ -31,6 +31,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; import org.reactivestreams.FlowAdapters; import org.slf4j.LoggerFactory; @@ -101,6 +102,20 @@ public Flow.Publisher converseStream(C return awsResponseProcessor; } + @Override + public Flow.Publisher converseUnifiedStream(ConverseStreamRequest request) + throws ElasticsearchException { + var awsResponseProcessor = new AmazonBedrockUnifiedStreamingChatProcessor(threadPool); + internalClient.converseStream( + request, + ConverseStreamResponseHandler.builder().subscriber(() -> FlowAdapters.toSubscriber(awsResponseProcessor)).build() + ).exceptionally(e -> { + awsResponseProcessor.onError(e); + return null; // Void + }); + return awsResponseProcessor; + } + private void onFailure(ActionListener listener, Throwable t, String method) { ExceptionsHelper.maybeDieOnAnotherThread(t); var unwrappedException = t; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedChatCompletionExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedChatCompletionExecutor.java new file mode 100644 index 0000000000000..b4b7f5b1e1d89 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedChatCompletionExecutor.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.client; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponseHandler; + +import java.util.function.Supplier; + +public class AmazonBedrockUnifiedChatCompletionExecutor extends AmazonBedrockExecutor { + private final AmazonBedrockUnifiedChatCompletionRequest chatCompletionRequest; + + protected AmazonBedrockUnifiedChatCompletionExecutor( + AmazonBedrockUnifiedChatCompletionRequest request, + AmazonBedrockResponseHandler responseHandler, + Logger logger, + Supplier hasRequestCompletedFunction, + ActionListener inferenceResultsListener, + AmazonBedrockClientCache clientCache + ) { + super(request, responseHandler, logger, hasRequestCompletedFunction, inferenceResultsListener, clientCache); + this.chatCompletionRequest = request; + } + + @Override + protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) { + if (chatCompletionRequest.isStreaming()) { + var publisher = chatCompletionRequest.executeStreamChatCompletionRequest(awsBedrockClient); + inferenceResultsListener.onResponse(new StreamingUnifiedChatCompletionResults(publisher)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedStreamingChatProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedStreamingChatProcessor.java new file mode 100644 index 0000000000000..d650ef9dfe6de --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedStreamingChatProcessor.java @@ -0,0 +1,469 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.client; + +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStart; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; +import software.amazon.awssdk.services.bedrockruntime.model.MessageStartEvent; +import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent; +import software.amazon.awssdk.services.bedrockruntime.model.StopReason; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.Strings; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; + +import java.util.ArrayDeque; +import java.util.List; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; + +@SuppressWarnings("checkstyle:LineLength") +class AmazonBedrockUnifiedStreamingChatProcessor + implements + Flow.Processor { + private static final Logger logger = LogManager.getLogger(AmazonBedrockStreamingChatProcessor.class); + + private final AtomicReference error = new AtomicReference<>(null); + private final AtomicLong demand = new AtomicLong(0); + private final AtomicBoolean isDone = new AtomicBoolean(false); + private final AtomicBoolean onCompleteCalled = new AtomicBoolean(false); + private final AtomicBoolean onErrorCalled = new AtomicBoolean(false); + private final ThreadPool threadPool; + private volatile Flow.Subscriber downstream; + private volatile Flow.Subscription upstream; + + AmazonBedrockUnifiedStreamingChatProcessor(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + @Override + public void subscribe(Flow.Subscriber subscriber) { + if (downstream == null) { + downstream = subscriber; + downstream.onSubscribe(new StreamSubscription()); + } else { + subscriber.onError(new IllegalStateException("Subscriber already set.")); + } + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + if (upstream == null) { + upstream = subscription; + var currentRequestCount = demand.getAndUpdate(i -> 0); + if (currentRequestCount > 0) { + upstream.request(currentRequestCount); + } + } else { + subscription.cancel(); + } + } + + @Override + public void onNext(ConverseStreamOutput item) { + var chunks = new ArrayDeque(1); + + var eventType = item.sdkEventType(); + switch (eventType) { + case ConverseStreamOutput.EventType.MESSAGE_START -> { + demand.set(0); // reset demand before we fork to another thread + item.accept( + ConverseStreamResponseHandler.Visitor.builder().onMessageStart(event -> handleMessageStart(event, chunks)).build() + ); + return; + } + case ConverseStreamOutput.EventType.MESSAGE_STOP -> { + demand.set(0); // reset demand before we fork to another thread + item.accept( + ConverseStreamResponseHandler.Visitor.builder().onMessageStop(event -> handleMessageStop(event, chunks)).build() + ); + return; + } + case ConverseStreamOutput.EventType.CONTENT_BLOCK_START -> { + demand.set(0); // reset demand before we fork to another thread + item.accept( + ConverseStreamResponseHandler.Visitor.builder() + .onContentBlockStart(event -> handleContentBlockStart(event, chunks)) + .build() + ); + return; + } + case ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA -> { + demand.set(0); // reset demand before we fork to another thread + item.accept( + ConverseStreamResponseHandler.Visitor.builder() + .onContentBlockDelta(event -> handleContentBlockDelta(event, chunks)) + .build() + ); + return; + } + case ConverseStreamOutput.EventType.METADATA -> { + demand.set(0); // reset demand before we fork to another thread + item.accept(ConverseStreamResponseHandler.Visitor.builder().onMetadata(event -> handleMetadata(event, chunks)).build()); + return; + } + case ConverseStreamOutput.EventType.CONTENT_BLOCK_STOP -> { + demand.set(0); // reset demand before we fork to another thread + item.accept(ConverseStreamResponseHandler.Visitor.builder().onContentBlockStop(event -> Stream.empty()).build()); + return; + } + default -> { + logger.debug("Unknown event type [{}] for line [{}].", eventType, item); + } + } + + if (item.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA) { + + } else { + upstream.request(1); + } + } + + private void handleMessageStart(MessageStartEvent event, ArrayDeque chunks) { + runOnUtilityThreadPool(() -> { + try { + var messageStart = handleMessageStart(event); + messageStart.forEach(chunks::offer); + } catch (Exception e) { + logger.warn("Failed to parse message start event from Amazon Bedrock provider: {}", event); + } + if (chunks.isEmpty()) { + upstream.request(1); + } else { + downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(chunks)); + } + }); + } + + private void handleMessageStop(MessageStopEvent event, ArrayDeque chunks) { + runOnUtilityThreadPool(() -> { + try { + var messageStop = handleMessageStop(event); + messageStop.forEach(chunks::offer); + } catch (Exception e) { + logger.warn("Failed to parse message stop event from Amazon Bedrock provider: {}", event); + } + if (chunks.isEmpty()) { + upstream.request(1); + } else { + downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(chunks)); + } + }); + } + + private void handleContentBlockStart( + ContentBlockStartEvent event, + ArrayDeque chunks + ) { + try { + var contentBlockStart = handleContentBlockStart(event); + contentBlockStart.forEach(chunks::offer); + } catch (Exception e) { + logger.warn("Failed to parse block start event from Amazon Bedrock provider: {}", event); + } + var results = new StreamingUnifiedChatCompletionResults.Results(chunks); + downstream.onNext(results); + } + + private void handleContentBlockDelta( + ContentBlockDeltaEvent event, + ArrayDeque chunks + ) { + runOnUtilityThreadPool(() -> { + try { + var contentBlockDelta = handleContentBlockDelta(event); + contentBlockDelta.forEach(chunks::offer); + } catch (Exception e) { + logger.warn("Failed to parse content block delta event from Amazon Bedrock provider: {}", event); + } + var results = new StreamingUnifiedChatCompletionResults.Results(chunks); + downstream.onNext(results); + }); + } + + private void handleMetadata( + ConverseStreamMetadataEvent event, + ArrayDeque chunks + ) { + runOnUtilityThreadPool(() -> { + try { + var messageDelta = handleMetadata(event); + messageDelta.forEach(chunks::offer); + } catch (Exception e) { + logger.warn("Failed to parse metadata event from Amazon Bedrock provider: {}", event); + } + var results = new StreamingUnifiedChatCompletionResults.Results(chunks); + downstream.onNext(results); + }); + } + + @Override + public void onError(Throwable amazonBedrockRuntimeException) { + ExceptionsHelper.maybeDieOnAnotherThread(amazonBedrockRuntimeException); + error.set( + new ElasticsearchException( + Strings.format("AmazonBedrock StreamingChatProcessor failure: [%s]", amazonBedrockRuntimeException.getMessage()), + amazonBedrockRuntimeException + ) + ); + if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onErrorCalled.compareAndSet(false, true)) { + runOnUtilityThreadPool(() -> downstream.onError(amazonBedrockRuntimeException)); + } + } + + private boolean checkAndResetDemand() { + return demand.getAndUpdate(i -> 0L) > 0L; + } + + @Override + public void onComplete() { + if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onCompleteCalled.compareAndSet(false, true)) { + downstream.onComplete(); + } + } + + private void runOnUtilityThreadPool(Runnable runnable) { + try { + threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(runnable); + } catch (Exception e) { + logger.error(Strings.format("failed to fork [%s] to utility thread pool", runnable), e); + } + } + + private class StreamSubscription implements Flow.Subscription { + @Override + public void request(long n) { + if (n > 0L) { + demand.updateAndGet(i -> { + var sum = i + n; + return sum >= 0 ? sum : Long.MAX_VALUE; + }); + if (upstream == null) { + // wait for upstream to subscribe before forwarding request + return; + } + if (upstreamIsRunning()) { + requestOnMlThread(n); + } else if (error.get() != null && onErrorCalled.compareAndSet(false, true)) { + downstream.onError(error.get()); + } else if (onCompleteCalled.compareAndSet(false, true)) { + downstream.onComplete(); + } + } else { + cancel(); + downstream.onError(new IllegalStateException("Cannot request a negative number.")); + } + } + + private boolean upstreamIsRunning() { + return isDone.get() == false && error.get() == null; + } + + private void requestOnMlThread(long n) { + var currentThreadPool = EsExecutors.executorName(Thread.currentThread().getName()); + if (UTILITY_THREAD_POOL_NAME.equalsIgnoreCase(currentThreadPool)) { + upstream.request(n); + } else { + runOnUtilityThreadPool(() -> upstream.request(n)); + } + } + + @Override + public void cancel() { + if (upstream != null && upstreamIsRunning()) { + upstream.cancel(); + } + } + } + + /** + * Parse a MessageStartEvent into a ChatCompletionChunk stream + * @param event the MessageStartEvent data + * @return a stream of ChatCompletionChunk + */ + public static Stream handleMessageStart(MessageStartEvent event) { + var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, event.roleAsString(), null); + var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, 0); + var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null); + return Stream.of(chunk); + } + + /** + * Parse a MessageStopEvent into a ChatCompletionChunk stream + * @param event the MessageStopEvent data + * @return a stream of ChatCompletionChunk + */ + public static Stream handleMessageStop(MessageStopEvent event) { + var finishReason = handleFinishReason(event.stopReason()); + var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(null, finishReason, 0); + var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null); + return Stream.of(chunk); + } + + public static Stream processEvent(MessageStopEvent event) { + var finishReason = handleFinishReason(event.stopReason()); + var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(null, finishReason, 0); + var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null); + return Stream.of(chunk); + } + + /** + * This ensures consistent handling of completion termination across different providers. + * For example, both "stop_sequence" and "end_turn" from Bedrock map to the unified "stop" reason. + * @param stopReason the stop reason + * @return a stop reason + */ + public static String handleFinishReason(StopReason stopReason) { + switch (stopReason) { + case StopReason.TOOL_USE -> { + return "FinishReasonToolCalls"; + } + case StopReason.MAX_TOKENS -> { + return "FinishReasonLength"; + } + case StopReason.CONTENT_FILTERED, StopReason.GUARDRAIL_INTERVENED -> { + return "FinishReasonContentFilter"; + } + case StopReason.END_TURN, StopReason.STOP_SEQUENCE -> { + return "FinishReasonStop"; + } + default -> { + logger.debug("unhandled stop reason [{}].", stopReason); + return "FinishReasonStop"; + } + } + } + + public StreamingUnifiedChatCompletionResults.ChatCompletionChunk createBaseChunk() { + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, null, null, "chat.completion.chunk", null); + } + + /** + * processes a tool initialization event from Bedrock + * This occurs when the model first decides to use a tool, providing its name and ID. + * Parse a MessageStartEvent into a ToolCall stream + * @param start the ContentBlockStart data + * @return a ToolCall + */ + private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall handleToolUseStart( + ContentBlockStart start + ) { + var type = start.type(); + var toolUse = start.toolUse(); + var function = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function(null, toolUse.name()); + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 0, + toolUse.toolUseId(), + function, + type.name() + ); + } + + /** + * processes incremental updates to a tool call + * This typically contains the arguments that the model wants to pass to the tool. + * Parse a ContentBlockDelta into a ToolCall stream + * @param delta the ContentBlockDelta data + * @return a ToolCall + */ + private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall handleToolUseDelta( + ContentBlockDelta delta + ) { + var type = delta.type(); + var toolUse = delta.toolUse(); + var function = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function(toolUse.input(), null); + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall(0, null, function, type.name()); + } + + /** + * Parse a ContentBlockStartEvent into a ChatCompletionChunk stream + * @param event the content block start data + * @return a stream of ChatCompletionChunk + */ + public static Stream handleContentBlockStart(ContentBlockStartEvent event) { + var index = event.contentBlockIndex(); + var type = event.start().type(); + + switch (type) { + case ContentBlockStart.Type.TOOL_USE -> { + var toolCall = handleToolUseStart(event.start()); + var role = "assistant"; + var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, role, List.of(toolCall)); + var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, index); + var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null); + return Stream.of(chunk); + } + default -> logger.debug("unhandled content block start type [{}].", type); + } + var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, null, null); + var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, index); + var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null); + return Stream.of(chunk); + } + + /** + * processes incremental content updates + * Parse a ContentBlockDeltaEvent into a ChatCompletionChunk stream + * @param event the event data + * @return a stream of ChatCompletionChunk + */ + public static Stream handleContentBlockDelta(ContentBlockDeltaEvent event) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta delta = null; + var type = event.delta().type(); + var content = event.delta().text(); + + switch (type) { + case ContentBlockDelta.Type.TEXT -> { + delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(content, null, null, null); + } + case ContentBlockDelta.Type.TOOL_USE -> { + var toolCall = handleToolUseDelta(event.delta()); + delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(content, null, null, List.of(toolCall)); + } + default -> logger.debug("unknown content block delta type [{}].", type); + } + var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, event.contentBlockIndex()); + var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null); + return Stream.of(chunk); + } + + /** + * processes usage statistics + * Parse a ConverseStreamMetadataEvent into a ChatCompletionChunk stream + * @param event the event data + * @return a stream of ChatCompletionChunk + */ + public static Stream handleMetadata(ConverseStreamMetadataEvent event) { + var inputTokens = event.usage().inputTokens(); + var outputTokens = event.usage().outputTokens(); + var totalTokens = event.usage().totalTokens(); + var usage = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(outputTokens, inputTokens, totalTokens); + var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, null, null), + null, + 0 + ); + var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, usage); + return Stream.of(chunk); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java index 68c1c884ab63a..735afbfc1610a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; @@ -19,6 +20,7 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.action.AmazonBedrockActionVisitor; import java.util.Map; +import java.util.Objects; public class AmazonBedrockChatCompletionModel extends AmazonBedrockModel { @@ -32,6 +34,27 @@ public static AmazonBedrockChatCompletionModel of(AmazonBedrockChatCompletionMod return new AmazonBedrockChatCompletionModel(completionModel, taskSettingsToUse); } + public static AmazonBedrockChatCompletionModel of(AmazonBedrockChatCompletionModel model, UnifiedCompletionRequest request) { + if (request.model() == null) { + return model; + } + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new AmazonBedrockChatCompletionServiceSettings( + originalModelServiceSettings.region(), + Objects.requireNonNull(request.model(), originalModelServiceSettings.modelId()), + originalModelServiceSettings.provider(), + originalModelServiceSettings.rateLimitSettings() + ); + return new AmazonBedrockChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getTaskSettings(), + model.getSecretSettings() + ); + } + public AmazonBedrockChatCompletionModel( String inferenceEntityId, TaskType taskType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockConverseUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockConverseUtils.java index a22d9ca3e850b..23bf649534f28 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockConverseUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockConverseUtils.java @@ -13,6 +13,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; import java.util.List; import java.util.Optional; @@ -28,6 +29,17 @@ public static List getConverseMessageList(List texts) { .toList(); } + public static List getUnifiedConverseMessageList(List messages) { + return messages.stream() + .map( + message -> Message.builder() + .role(message.role()) + .content(ContentBlock.builder().text(message.content().toString()).build()) + .build() + ) + .toList(); + } + public static Optional inferenceConfig(AmazonBedrockConverseRequestEntity request) { if (request.temperature() != null || request.topP() != null || request.maxTokenCount() != null) { var builder = InferenceConfiguration.builder(); @@ -47,6 +59,29 @@ public static Optional inferenceConfig(AmazonBedrockConv return Optional.empty(); } + public static Optional inferenceConfig(AmazonBedrockUnifiedConverseRequestEntity request) { + if (request.temperature() != null || request.topP() != null || request.maxCompletionTokens() != null) { + var builder = InferenceConfiguration.builder(); + if (request.temperature() != null) { + builder.temperature(request.temperature().floatValue()); + } + + if (request.topP() != null) { + builder.topP(request.topP().floatValue()); + } + + if (request.maxCompletionTokens() != null) { + builder.maxTokens(Math.toIntExact(request.maxCompletionTokens())); + } + + if (request.stop() != null) { + builder.stopSequences(request.stop()); + } + return Optional.of(builder.build()); + } + return Optional.empty(); + } + @Nullable public static List additionalTopK(@Nullable Double topK) { if (topK == null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockUnifiedChatCompletionEntityFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockUnifiedChatCompletionEntityFactory.java new file mode 100644 index 0000000000000..397073933028f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockUnifiedChatCompletionEntityFactory.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; + +import java.util.Objects; + +public class AmazonBedrockUnifiedChatCompletionEntityFactory { + public static AmazonBedrockUnifiedConverseRequestEntity createEntity( + AmazonBedrockChatCompletionModel model, + UnifiedCompletionRequest request + ) { + Objects.requireNonNull(model); + Objects.requireNonNull(request); + var serviceSettings = model.getServiceSettings(); + + var messages = request.messages() + .stream() + .map( + message -> new UnifiedCompletionRequest.Message( + message.content(), + toBedrockRole(message.role()), + message.toolCallId(), + message.toolCalls() + ) + ) + .toList(); + + switch (serviceSettings.provider()) { + case ANTHROPIC, AI21LABS, AMAZONTITAN, COHERE, META, MISTRAL -> { + return new AmazonBedrockUnifiedConverseRequestEntity( + messages, + request.model(), + request.maxCompletionTokens(), + request.stop(), + request.temperature(), + request.toolChoice(), + request.tools(), + request.topP() + ); + } + default -> { + return null; + } + } + } + + private static String toBedrockRole(String role) { + return role == null ? "user" : role; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockUnifiedChatCompletionRequest.java new file mode 100644 index 0000000000000..e007d497051a9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockUnifiedChatCompletionRequest.java @@ -0,0 +1,184 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion; + +import software.amazon.awssdk.core.document.Document; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; +import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent; +import software.amazon.awssdk.services.bedrockruntime.model.SpecificToolChoice; +import software.amazon.awssdk.services.bedrockruntime.model.Tool; +import software.amazon.awssdk.services.bedrockruntime.model.ToolChoice; +import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema; +import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockBaseClient; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.services.amazonbedrock.response.completion.AmazonBedrockChatCompletionResponseListener; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Flow; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockConverseUtils.getUnifiedConverseMessageList; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockConverseUtils.inferenceConfig; + +public class AmazonBedrockUnifiedChatCompletionRequest extends AmazonBedrockRequest { + public static final String USER_ROLE = "user"; + private final AmazonBedrockUnifiedConverseRequestEntity requestEntity; + private AmazonBedrockChatCompletionResponseListener listener; + private final boolean stream; + + public AmazonBedrockUnifiedChatCompletionRequest( + AmazonBedrockChatCompletionModel model, + AmazonBedrockUnifiedConverseRequestEntity requestEntity, + @Nullable TimeValue timeout, + boolean stream + ) { + super(model, timeout); + this.requestEntity = Objects.requireNonNull(requestEntity); + this.stream = stream; + } + + public Flow.Publisher executeStreamChatCompletionRequest( + AmazonBedrockBaseClient awsBedrockClient + ) throws ExecutionException, InterruptedException { + var converseStreamRequest = ConverseStreamRequest.builder() + .messages(getUnifiedConverseMessageList(requestEntity.messages())) + .modelId(amazonBedrockModel.model()); + + if (requestEntity.tools() != null) { + requestEntity.tools().forEach(tool -> { + try { + converseStreamRequest.toolConfig( + ToolConfiguration.builder() + .tools( + Tool.builder() + .toolSpec( + ToolSpecification.builder() + .name(tool.function().name()) + .description(tool.function().description()) + .inputSchema(ToolInputSchema.fromJson(Document.fromMap(paramToDocumentMap(tool)))) + .build() + ) + .build() + ) + .toolChoice( + ToolChoice.builder().tool(SpecificToolChoice.builder().name(tool.function().name()).build()).build() + ) + .build() + ); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + inferenceConfig(requestEntity).ifPresent(converseStreamRequest::inferenceConfig); + var response = awsBedrockClient.converseUnifiedStream(converseStreamRequest.build()); + + var toolRequested = new CompletableFuture(); + final String[] toolUseIdHolder = new String[1]; + final StringBuilder toolJsonArgs = new StringBuilder(); + final StringBuilder assistantText = new StringBuilder(); + + var handler = ConverseStreamResponseHandler.builder().onEventStream(es -> es.subscribe(event -> { + switch (event.sdkEventType()) { + case MESSAGE_START: + break; + case CONTENT_BLOCK_START: + var start = ((ContentBlockStartEvent) event).start(); + if (start.toolUse() != null) { + toolUseIdHolder[0] = start.toolUse().toolUseId(); + } + break; + case CONTENT_BLOCK_DELTA: + var delta = ((ContentBlockDeltaEvent) event).delta(); + if (delta.toolUse() != null && delta.toolUse().input() != null) { + toolJsonArgs.append(delta.toolUse().input()); + } + if (delta.text() != null) { + assistantText.append(delta.text()); + } + break; + case MESSAGE_STOP: + var stop = ((MessageStopEvent) event).stopReason(); + if ("tool_use".equalsIgnoreCase(stop.name())) { + toolRequested.complete(true); + } else { + toolRequested.complete(false); + } + break; + default: + } + })).onResponse(r -> toolRequested.complete(true)).onError(toolRequested::completeExceptionally); + + handler.subscriber(converseStreamOutput -> getUnifiedConverseMessageList(requestEntity.messages()).forEach(toolJsonArgs::append)); + + if (Boolean.TRUE.equals(toolRequested.get())) { + toolJsonArgs.toString().contains("args"); + Map result = Map.of("tool_use", toolUseIdHolder[0]); + // var toolResultBlock = ContentBlock + // .fromToolResult(ToolResultContentBlock.builder() + // .document(DocumentBlock.builder().context(result).build())); + + } + inferenceConfig(requestEntity).ifPresent(converseStreamRequest::inferenceConfig); + return awsBedrockClient.converseUnifiedStream(converseStreamRequest.build()); + } + + private Document toDocument(Object value) { + return switch (value) { + case null -> Document.fromNull(); + case String stringValue -> Document.fromString(stringValue); + case Integer numberValue -> Document.fromNumber(numberValue); + case Map mapValue -> { + final Map converted = new HashMap<>(); + for (Map.Entry entry : mapValue.entrySet()) { + converted.put(String.valueOf(entry.getKey()), toDocument(entry.getValue())); + } + yield Document.fromMap(converted); + } + default -> Document.mapBuilder().build(); + }; + } + + private Map paramToDocumentMap(UnifiedCompletionRequest.Tool tool) throws IOException { + Map paramDocuments = new HashMap<>(); + for (Map.Entry entry : tool.function().parameters().entrySet()) { + paramDocuments.put(entry.getKey(), toDocument(entry.getValue())); + } + return paramDocuments; + } + + @Override + protected void executeRequest(AmazonBedrockBaseClient client) {} + + @Override + public TaskType taskType() { + return TaskType.CHAT_COMPLETION; + } + + @Override + public boolean isStreaming() { + return stream; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockUnifiedConverseRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockUnifiedConverseRequestEntity.java new file mode 100644 index 0000000000000..c888e354ab721 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockUnifiedConverseRequestEntity.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.UnifiedCompletionRequest; + +import java.util.List; + +public record AmazonBedrockUnifiedConverseRequestEntity( + List messages, + @Nullable String model, + @Nullable Long maxCompletionTokens, + @Nullable List stop, + @Nullable Float temperature, + @Nullable UnifiedCompletionRequest.ToolChoice toolChoice, + @Nullable List tools, + @Nullable Float topP +) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/completion/AmazonBedrockChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/completion/AmazonBedrockChatCompletionResponseHandler.java index 6c4bd4862aac2..3c5abe9054b11 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/completion/AmazonBedrockChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/completion/AmazonBedrockChatCompletionResponseHandler.java @@ -36,4 +36,9 @@ public String getRequestType() { public void acceptChatCompletionResponseObject(ConverseResponse response) { this.responseResult = response; } + + @Override + public boolean canHandleStreamingResponses() { + return true; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index d7eb32861da92..cb74e0d6e4a39 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -32,6 +32,9 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; @@ -41,11 +44,15 @@ import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockMockRequestSender; +import org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockRequestSenderTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings; @@ -68,7 +75,10 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.cluster.service.TaskExecutorTests.createThreadPool; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.test.ESTestCase.assertThat; +import static org.elasticsearch.test.ESTestCase.terminate; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; @@ -95,16 +105,22 @@ public class AmazonBedrockServiceTests extends InferenceServiceTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; + private HttpClientManager clientManager; @Before public void init() throws Exception { + webServer.start(); threadPool = createThreadPool(inferenceUtilityExecutors()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); } @After public void shutdown() throws IOException { + clientManager.close(); terminate(threadPool); + webServer.close(); } public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOException { @@ -1460,4 +1476,54 @@ private Utils.PersistedConfig getPersistedConfigMap( new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) ); } + + public void testDoUnifiedInfer() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {"choices": [{"delta": {"content": "content", "role": "assistant"}, "finish_reason": null, "index": 0, \ + "logprobs": null}], "created": 1718345013, "id": "12345", "model": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", \ + "object": "chat.completion.chunk", "system_fingerprint": "fp_1234"} + + data: [DONE] + + """)); + doUnifiedCompletionInfer().hasNoErrors().hasEvent(""" + {"id":"12345","choices":[{"delta":{"content":"content","role":"assistant"},"index":0}],""" + """ + "model":"us.anthropic.claude-3-7-sonnet-20250219-v1:0","object":"chat.completion.chunk"}"""); + } + + private InferenceEventsAssertion doUnifiedCompletionInfer() throws Exception { + var model = AmazonBedrockChatCompletionModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access", + "secret" + ); + + try (var service = createAmazonBedrockService()) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + TIMEOUT, + listener + ); + return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); + } + } + + private AmazonBedrockService createService() { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + return new AmazonBedrockService( + factory, + AmazonBedrockRequestSenderTests.createSenderFactory(threadPool, Settings.EMPTY), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedStreamingChatProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedStreamingChatProcessorTests.java new file mode 100644 index 0000000000000..88ef4390e2216 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedStreamingChatProcessorTests.java @@ -0,0 +1,191 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.client; + +import software.amazon.awssdk.services.bedrockruntime.model.*; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.util.concurrent.Flow; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.hamcrest.Matchers.*; +import static org.hamcrest.Matchers.isA; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +public class AmazonBedrockUnifiedStreamingChatProcessorTests extends ESTestCase { + private AmazonBedrockUnifiedStreamingChatProcessor processor; + + @Before + public void setUp() throws Exception { + super.setUp(); + ThreadPool threadPool = mock(); + when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); + processor = new AmazonBedrockUnifiedStreamingChatProcessor(threadPool); + } + + /** + * We do not issue requests on subscribe because the downstream will control the pacing. + */ + public void testOnSubscribeBeforeDownstreamDoesNotRequest() { + var upstream = mock(Flow.Subscription.class); + processor.onSubscribe(upstream); + + verify(upstream, never()).request(anyLong()); + } + + /** + * If the downstream requests data before the upstream is set, when the upstream is set, we will forward the pending requests to it. + */ + public void testOnSubscribeAfterDownstreamRequests() { + var expectedRequestCount = randomLongBetween(1, 500); + Flow.Subscriber subscriber = mock(); + doAnswer(ans -> { + Flow.Subscription sub = ans.getArgument(0); + sub.request(expectedRequestCount); + return null; + }).when(subscriber).onSubscribe(any()); + processor.subscribe(subscriber); + + var upstream = mock(Flow.Subscription.class); + processor.onSubscribe(upstream); + + verify(upstream, times(1)).request(anyLong()); + } + + public void testCancelDuplicateSubscriptions() { + processor.onSubscribe(mock()); + + var upstream = mock(Flow.Subscription.class); + processor.onSubscribe(upstream); + + verify(upstream, times(1)).cancel(); + verifyNoMoreInteractions(upstream); + } + + public void testMultiplePublishesCallsOnError() { + processor.subscribe(mock()); + + Flow.Subscriber subscriber = mock(); + processor.subscribe(subscriber); + + verify(subscriber, times(1)).onError(assertArg(e -> { + assertThat(e, isA(IllegalStateException.class)); + assertThat(e.getMessage(), equalTo("Subscriber already set.")); + })); + } + + private ConverseStreamOutput output(String text) { + ConverseStreamOutput output = mock(); + when(output.sdkEventType()).thenReturn(ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA); + doAnswer(ans -> { + ConverseStreamResponseHandler.Visitor visitor = ans.getArgument(0); + ContentBlockDelta delta = ContentBlockDelta.fromText(text); + ContentBlockDeltaEvent event = ContentBlockDeltaEvent.builder().delta(delta).build(); + visitor.visitContentBlockDelta(event); + return null; + }).when(output).accept(any()); + return output; + } + + private void verifyText(Flow.Subscriber downstream, String expectedText) { + verify(downstream, times(1)).onNext(assertArg(results -> { + assertThat(results, notNullValue()); + assertThat(results.chunks().size(), equalTo(1)); + // assertThat(results.chunks().getFirst().choices().getFirst(), equalTo(expectedText)); + })); + } + + public void verifyCompleteBeforeRequest() { + processor.onComplete(); + + Flow.Subscriber downstream = mock(); + var sub = ArgumentCaptor.forClass(Flow.Subscription.class); + processor.subscribe(downstream); + verify(downstream).onSubscribe(sub.capture()); + + sub.getValue().request(1); + verify(downstream, times(1)).onComplete(); + } + + public void verifyCompleteAfterRequest() { + + Flow.Subscriber downstream = mock(); + var sub = ArgumentCaptor.forClass(Flow.Subscription.class); + processor.subscribe(downstream); + verify(downstream).onSubscribe(sub.capture()); + + sub.getValue().request(1); + processor.onComplete(); + verify(downstream, times(1)).onComplete(); + } + + public void verifyOnErrorBeforeRequest() { + var expectedError = BedrockRuntimeException.builder().message("ahhhhhh").build(); + processor.onError(expectedError); + + Flow.Subscriber downstream = mock(); + var sub = ArgumentCaptor.forClass(Flow.Subscription.class); + processor.subscribe(downstream); + verify(downstream).onSubscribe(sub.capture()); + + sub.getValue().request(1); + verify(downstream, times(1)).onError(assertArg(e -> { + assertThat(e, isA(ElasticsearchException.class)); + assertThat(e.getCause(), is(expectedError)); + })); + } + + public void verifyOnErrorAfterRequest() { + var expectedError = BedrockRuntimeException.builder().message("ahhhhhh").build(); + + Flow.Subscriber downstream = mock(); + var sub = ArgumentCaptor.forClass(Flow.Subscription.class); + processor.subscribe(downstream); + verify(downstream).onSubscribe(sub.capture()); + + sub.getValue().request(1); + processor.onError(expectedError); + verify(downstream, times(1)).onError(assertArg(e -> { + assertThat(e, isA(ElasticsearchException.class)); + assertThat(e.getCause(), is(expectedError)); + })); + } + + public void verifyAsyncOnCompleteIsStillDeliveredSynchronously() { + mockUpstream(); + + Flow.Subscriber downstream = mock(); + var sub = ArgumentCaptor.forClass(Flow.Subscription.class); + processor.subscribe(downstream); + verify(downstream).onSubscribe(sub.capture()); + + sub.getValue().request(1); + verify(downstream, times(1)).onNext(any()); + processor.onComplete(); + verify(downstream, times(0)).onComplete(); + sub.getValue().request(1); + verify(downstream, times(1)).onComplete(); + } + + private void mockUpstream() { + Flow.Subscription upstream = mock(); + doAnswer(ans -> { + processor.onNext(output(randomIdentifier())); + return null; + }).when(upstream).request(anyLong()); + processor.onSubscribe(upstream); + } +}