Skip to content

Commit d870f42

Browse files
authored
[ML] Allow InputType for Bedrock Titan (#127021)
Semantic Search can now send InputType as part of the request to non-Cohere Bedrock models. Fix #126709
1 parent f461f90 commit d870f42

File tree

2 files changed

+26
-45
lines changed

2 files changed

+26
-45
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
import java.util.HashMap;
5454
import java.util.List;
5555
import java.util.Map;
56-
import java.util.Objects;
5756
import java.util.Set;
5857

5958
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
@@ -81,15 +80,14 @@ public class AmazonBedrockService extends SenderService {
8180

8281
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
8382

84-
private static final AmazonBedrockProvider PROVIDER_WITH_TASK_TYPE = AmazonBedrockProvider.COHERE;
85-
8683
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
8784
InputType.INGEST,
8885
InputType.SEARCH,
8986
InputType.CLASSIFICATION,
9087
InputType.CLUSTERING,
9188
InputType.INTERNAL_INGEST,
92-
InputType.INTERNAL_SEARCH
89+
InputType.INTERNAL_SEARCH,
90+
InputType.UNSPECIFIED
9391
);
9492

9593
public AmazonBedrockService(
@@ -130,21 +128,8 @@ protected void doInfer(
130128

131129
@Override
132130
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
133-
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
134-
// inputType is only allowed when provider=cohere for text embeddings
135-
var provider = baseAmazonBedrockModel.provider();
136-
137-
if (Objects.equals(provider, PROVIDER_WITH_TASK_TYPE)) {
138-
// input type parameter allowed, so verify it is valid if specified
139-
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
140-
} else {
141-
// input type parameter not allowed so throw validation error if it is specified and not internal
142-
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(
143-
inputType,
144-
validationException,
145-
Strings.format("Invalid value [%s] received. [%s] is not allowed for provider [%s]", inputType, "input_type", provider)
146-
);
147-
}
131+
if (model instanceof AmazonBedrockModel) {
132+
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
148133
}
149134
}
150135

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc
988988
verifyNoMoreInteractions(sender);
989989
}
990990

991-
public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForProviderThatDoesNotAcceptTaskType() throws IOException {
991+
public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException {
992992
var sender = mock(Sender.class);
993993
var factory = mock(HttpRequestSender.Factory.class);
994994
when(factory.createSender()).thenReturn(sender);
@@ -1006,37 +1006,33 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForProviderTh
10061006
"secret"
10071007
);
10081008

1009-
try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) {
1009+
try (
1010+
var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool));
1011+
var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()
1012+
) {
1013+
var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })));
1014+
requestSender.enqueue(results);
10101015
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
1011-
var thrownException = expectThrows(
1012-
ValidationException.class,
1013-
() -> service.infer(
1014-
model,
1015-
null,
1016-
null,
1017-
null,
1018-
List.of(""),
1019-
false,
1020-
new HashMap<>(),
1021-
InputType.INGEST,
1022-
InferenceAction.Request.DEFAULT_TIMEOUT,
1023-
listener
1024-
)
1025-
);
1026-
assertThat(
1027-
thrownException.getMessage(),
1028-
is("Validation Failed: 1: Invalid value [ingest] received. [input_type] is not allowed for provider [amazontitan];")
1016+
service.infer(
1017+
model,
1018+
null,
1019+
null,
1020+
null,
1021+
List.of("abc"),
1022+
false,
1023+
new HashMap<>(),
1024+
InputType.INGEST,
1025+
InferenceAction.Request.DEFAULT_TIMEOUT,
1026+
listener
10291027
);
10301028

1031-
verify(factory, times(1)).createSender();
1032-
verify(sender, times(1)).start();
1029+
var result = listener.actionGet(TIMEOUT);
1030+
1031+
assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.678F }))));
10331032
}
1034-
verify(sender, times(1)).close();
1035-
verifyNoMoreInteractions(factory);
1036-
verifyNoMoreInteractions(sender);
10371033
}
10381034

1039-
public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException {
1035+
public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException {
10401036
var sender = mock(Sender.class);
10411037
var factory = mock(HttpRequestSender.Factory.class);
10421038
when(factory.createSender()).thenReturn(sender);

0 commit comments

Comments
 (0)