Skip to content

Commit 6b2448a

Browse files
Fixing tool calls and message merging
1 parent 20af2ae commit 6b2448a

File tree

8 files changed

+409
-120
lines changed

8 files changed

+409
-120
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
3838

3939
public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> {
40+
public static final ElasticsearchStatusException CHAT_COMPLETION_STREAMING_ONLY_EXCEPTION = new ElasticsearchStatusException(
41+
"The [chat_completion] task type only supports streaming, please try again with the _stream API",
42+
RestStatus.BAD_REQUEST
43+
);
4044
private final ModelRegistry modelRegistry;
4145
private final Client client;
4246

@@ -87,10 +91,7 @@ private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request,
8791

8892
try {
8993
if (request.isStreaming() == false) {
90-
throw new ElasticsearchStatusException(
91-
"The [chat_completion] task type only supports streaming, please try again with the _stream API",
92-
RestStatus.BAD_REQUEST
93-
);
94+
throw CHAT_COMPLETION_STREAMING_ONLY_EXCEPTION;
9495
}
9596

9697
UnifiedCompletionAction.Request unifiedRequest;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,16 @@ public static String unsupportedTaskTypeForInference(Model model, EnumSet<TaskTy
10221022
);
10231023
}
10241024

1025+
public static ElasticsearchStatusException createUnsupportedTaskTypeStatusException(Model model, EnumSet<TaskType> supportedTaskTypes) {
1026+
var responseString = ServiceUtils.unsupportedTaskTypeForInference(model, supportedTaskTypes);
1027+
1028+
if (model.getTaskType() == TaskType.CHAT_COMPLETION) {
1029+
responseString = responseString + " " + useChatCompletionUrlMessage(model);
1030+
}
1031+
1032+
return new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST);
1033+
}
1034+
10251035
public static String useChatCompletionUrlMessage(Model model) {
10261036
return org.elasticsearch.common.Strings.format(
10271037
"The task type for the inference entity is %s, please use the _inference/%s/%s/%s URL.",

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.elasticsearch.inference.TaskType;
3333
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
3434
import org.elasticsearch.rest.RestStatus;
35+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
3536
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
3637
import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker;
3738
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
@@ -63,6 +64,7 @@
6364
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
6465
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
6566
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
67+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUnsupportedTaskTypeStatusException;
6668
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
6769
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
6870
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -89,11 +91,16 @@ public class AmazonBedrockService extends SenderService {
8991

9092
private final Sender amazonBedrockSender;
9193

92-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
94+
// The task types exposed via the _inference/_services API
95+
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
9396
TaskType.TEXT_EMBEDDING,
9497
TaskType.COMPLETION,
9598
TaskType.CHAT_COMPLETION
9699
);
100+
/**
101+
* The task types that the {@link InferenceAction.Request} can accept.
102+
*/
103+
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
97104

98105
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
99106
InputType.INGEST,
@@ -154,6 +161,11 @@ protected void doInfer(
154161
TimeValue timeout,
155162
ActionListener<InferenceServiceResults> listener
156163
) {
164+
if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(model.getTaskType()) == false) {
165+
listener.onFailure(createUnsupportedTaskTypeStatusException(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES));
166+
return;
167+
}
168+
157169
if (model instanceof AmazonBedrockModel == false) {
158170
listener.onFailure(createInvalidModelException(model));
159171
return;
@@ -298,7 +310,7 @@ public InferenceServiceConfiguration getConfiguration() {
298310

299311
@Override
300312
public EnumSet<TaskType> supportedTaskTypes() {
301-
return supportedTaskTypes;
313+
return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
302314
}
303315

304316
private static AmazonBedrockModel createModel(
@@ -429,7 +441,9 @@ public static InferenceServiceConfiguration get() {
429441

430442
configurationMap.put(
431443
PROVIDER_FIELD,
432-
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The model provider for your deployment.")
444+
new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription(
445+
"The model provider for your deployment."
446+
)
433447
.setLabel("Provider")
434448
.setRequired(true)
435449
.setSensitive(false)
@@ -440,7 +454,7 @@ public static InferenceServiceConfiguration get() {
440454

441455
configurationMap.put(
442456
MODEL_FIELD,
443-
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
457+
new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription(
444458
"The base model ID or an ARN to a custom model based on a foundational model."
445459
)
446460
.setLabel("Model")
@@ -453,7 +467,7 @@ public static InferenceServiceConfiguration get() {
453467

454468
configurationMap.put(
455469
REGION_FIELD,
456-
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
470+
new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription(
457471
"The region that your model or ARN is deployed in."
458472
)
459473
.setLabel("Region")
@@ -482,13 +496,13 @@ public static InferenceServiceConfiguration get() {
482496
configurationMap.putAll(
483497
RateLimitSettings.toSettingsConfigurationWithDescription(
484498
"By default, the amazonbedrock service sets the number of requests allowed per minute to 240.",
485-
supportedTaskTypes
499+
SUPPORTED_TASK_TYPES_FOR_SERVICES_API
486500
)
487501
);
488502

489503
return new InferenceServiceConfiguration.Builder().setService(NAME)
490504
.setName(SERVICE_NAME)
491-
.setTaskTypes(supportedTaskTypes)
505+
.setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
492506
.setConfigurations(configurationMap)
493507
.build();
494508
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockChatCompletionExecutor.java

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
1414
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionRequest;
1515
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponseHandler;
16-
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.completion.AmazonBedrockChatCompletionResponseListener;
1716

1817
import java.util.function.Supplier;
1918

19+
import static org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy.CHAT_COMPLETION_STREAMING_ONLY_EXCEPTION;
20+
2021
public class AmazonBedrockChatCompletionExecutor extends AmazonBedrockExecutor {
2122
private final AmazonBedrockChatCompletionRequest chatCompletionRequest;
2223

@@ -34,16 +35,13 @@ protected AmazonBedrockChatCompletionExecutor(
3435

3536
@Override
3637
protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) {
37-
if (chatCompletionRequest.isStreaming()) {
38-
var publisher = chatCompletionRequest.executeStreamChatCompletionRequest(awsBedrockClient);
39-
inferenceResultsListener.onResponse(new StreamingUnifiedChatCompletionResults(publisher));
40-
} else {
41-
var completionResponseListener = new AmazonBedrockChatCompletionResponseListener(
42-
chatCompletionRequest,
43-
responseHandler,
44-
inferenceResultsListener
45-
);
46-
chatCompletionRequest.executeChatCompletionRequest(awsBedrockClient, completionResponseListener);
38+
// Chat completions only supports streaming
39+
if (chatCompletionRequest.isStreaming() == false) {
40+
inferenceResultsListener.onFailure(CHAT_COMPLETION_STREAMING_ONLY_EXCEPTION);
41+
return;
4742
}
43+
44+
var publisher = chatCompletionRequest.executeStreamChatCompletionRequest(awsBedrockClient);
45+
inferenceResultsListener.onResponse(new StreamingUnifiedChatCompletionResults(publisher));
4846
}
4947
}

0 commit comments

Comments
 (0)