Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/8.18.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
initial_elasticsearch_8_18_6,8840008
transform_check_for_dangling_tasks,8840011
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/8.19.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
initial_elasticsearch_8_19_3,8841067
transform_check_for_dangling_tasks,8841070
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.0.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
initial_elasticsearch_9_0_6,9000015
transform_check_for_dangling_tasks,9000018
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.1.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
initial_elasticsearch_9_1_4,9112007
transform_check_for_dangling_tasks,9112009
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.2.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ml_inference_endpoint_cache,9157000
initial_9.2.0,9185000
1 change: 1 addition & 0 deletions server/src/main/resources/transport/upper_bounds/9.3.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
initial_9.2.0,9185000
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
containsInAnyOrder(
List.of(
"ai21",
"amazonbedrock",
"llama",
"deepseek",
"elastic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionEntityFactory;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionEntityFactory;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.completion.AmazonBedrockChatCompletionResponseHandler;

import java.util.function.Supplier;

import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException;

public class AmazonBedrockChatCompletionRequestManager extends AmazonBedrockRequestManager {
private static final Logger logger = LogManager.getLogger(AmazonBedrockChatCompletionRequestManager.class);
private final AmazonBedrockChatCompletionModel model;
Expand All @@ -46,9 +51,45 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class);
var inputs = chatCompletionInput.getInputs();
var stream = chatCompletionInput.stream();
switch (inferenceInputs) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we create a new class that handles the UnifiedChatInput directly?

Once we've done that we can use that class directly in doUnifiedCompletionInfer.

case UnifiedChatInput uci -> execute(uci, requestSender, hasRequestCompletedFunction, listener);
case ChatCompletionInput cci -> execute(cci, requestSender, hasRequestCompletedFunction, listener);
default -> throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class);
}
}

private void execute(
UnifiedChatInput inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var inputs = inferenceInputs.getRequest();
var stream = inferenceInputs.stream();
var requestEntity = AmazonBedrockUnifiedChatCompletionEntityFactory.createEntity(model, inputs);
var request = new AmazonBedrockUnifiedChatCompletionRequest(model, requestEntity, timeout, stream);
var responseHandler = new AmazonBedrockChatCompletionResponseHandler();

try {
requestSender.send(logger, request, hasRequestCompletedFunction, responseHandler, listener);
} catch (Exception e) {
var errorMessage = Strings.format(
"Failed to send [completion] request from inference entity id [%s]",
request.getInferenceEntityId()
);
logger.warn(errorMessage, e);
listener.onFailure(new ElasticsearchException(errorMessage, e));
}
}

private void execute(
ChatCompletionInput inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var inputs = inferenceInputs.getInputs();
var stream = inferenceInputs.stream();
var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, inputs);
var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout, stream);
var responseHandler = new AmazonBedrockChatCompletionResponseHandler();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public final class AmazonBedrockProviderCapabilities {

public static boolean providerAllowsTaskType(AmazonBedrockProvider provider, TaskType taskType) {
switch (taskType) {
case COMPLETION -> {
case COMPLETION, CHAT_COMPLETION -> {
return chatCompletionProviders.contains(provider);
}
case TEXT_EMBEDDING -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD;
Expand All @@ -77,10 +76,15 @@
public class AmazonBedrockService extends SenderService {
public static final String NAME = "amazonbedrock";
private static final String SERVICE_NAME = "Amazon Bedrock";
public static final String COMPLETION_ERROR_PREFIX = "Amazon Bedrock chat completion";

private final Sender amazonBedrockSender;

private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.COMPLETION,
TaskType.CHAT_COMPLETION
);

private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
InputType.INGEST,
Expand Down Expand Up @@ -118,7 +122,7 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
infer(model, inputs, null, timeout, listener);
}

@Override
Expand All @@ -128,6 +132,16 @@ protected void doInfer(
Map<String, Object> taskSettings,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
infer(model, inputs, taskSettings, timeout, listener);
}

private void infer(
Model model,
InferenceInputs inputs,
Map<String, Object> taskSettings,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout);
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
Expand Down Expand Up @@ -303,6 +317,19 @@ private static AmazonBedrockModel createModel(
checkTaskSettingsForTextEmbeddingModel(model);
return model;
}
case CHAT_COMPLETION -> {
var model = new AmazonBedrockChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
context
);
checkProviderForTask(TaskType.CHAT_COMPLETION, model.provider());
return model;
}
case COMPLETION -> {
var model = new AmazonBedrockChatCompletionModel(
inferenceEntityId,
Expand All @@ -328,7 +355,7 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public Set<TaskType> supportedStreamingTasks() {
return COMPLETION_ONLY;
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;

import java.time.Instant;
import java.util.concurrent.Flow;
Expand All @@ -26,6 +27,9 @@ public interface AmazonBedrockClient {
Flow.Publisher<? extends InferenceServiceResults.Result> converseStream(ConverseStreamRequest converseStreamRequest)
throws ElasticsearchException;

Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> converseUnifiedStream(ConverseStreamRequest request)
throws ElasticsearchException;

void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener<InvokeModelResponse> responseListener)
throws ElasticsearchException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponseHandler;

Expand Down Expand Up @@ -83,6 +84,16 @@ protected AmazonBedrockExecutor createExecutor(
clientCache
);
}
case CHAT_COMPLETION -> {
return new AmazonBedrockUnifiedChatCompletionExecutor(
(AmazonBedrockUnifiedChatCompletionRequest) awsRequest,
awsResponse,
logger,
hasRequestTimedOutFunction,
listener,
clientCache
);
}
case TEXT_EMBEDDING -> {
return new AmazonBedrockEmbeddingsExecutor(
(AmazonBedrockEmbeddingsRequest) awsRequest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel;
import org.reactivestreams.FlowAdapters;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -101,6 +102,20 @@ public Flow.Publisher<? extends InferenceServiceResults.Result> converseStream(C
return awsResponseProcessor;
}

@Override
public Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> converseUnifiedStream(ConverseStreamRequest request)
throws ElasticsearchException {
var awsResponseProcessor = new AmazonBedrockUnifiedStreamingChatProcessor(threadPool);
internalClient.converseStream(
request,
ConverseStreamResponseHandler.builder().subscriber(() -> FlowAdapters.toSubscriber(awsResponseProcessor)).build()
).exceptionally(e -> {
awsResponseProcessor.onError(e);
return null; // Void
});
return awsResponseProcessor;
}

private void onFailure(ActionListener<?> listener, Throwable t, String method) {
ExceptionsHelper.maybeDieOnAnotherThread(t);
var unwrappedException = t;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.amazonbedrock.client;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponseHandler;

import java.util.function.Supplier;

public class AmazonBedrockUnifiedChatCompletionExecutor extends AmazonBedrockExecutor {
private final AmazonBedrockUnifiedChatCompletionRequest chatCompletionRequest;

protected AmazonBedrockUnifiedChatCompletionExecutor(
AmazonBedrockUnifiedChatCompletionRequest request,
AmazonBedrockResponseHandler responseHandler,
Logger logger,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> inferenceResultsListener,
AmazonBedrockClientCache clientCache
) {
super(request, responseHandler, logger, hasRequestCompletedFunction, inferenceResultsListener, clientCache);
this.chatCompletionRequest = request;
}

@Override
protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) {
if (chatCompletionRequest.isStreaming()) {
var publisher = chatCompletionRequest.executeStreamChatCompletionRequest(awsBedrockClient);
inferenceResultsListener.onResponse(new StreamingUnifiedChatCompletionResults(publisher));
}
}
}
Loading