Skip to content

Commit fb783b1

Browse files
committed
[ML] Allow InputType for Bedrock Titan
Semantic Search can now send InputType as part of the request to non-Cohere Bedrock models. Fix #126709
1 parent 8cb4493 commit fb783b1

File tree

2 files changed

+4
-67
lines changed

2 files changed

+4
-67
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
import java.util.HashMap;
5454
import java.util.List;
5555
import java.util.Map;
56-
import java.util.Objects;
5756
import java.util.Set;
5857

5958
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
@@ -81,15 +80,14 @@ public class AmazonBedrockService extends SenderService {
8180

8281
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
8382

84-
private static final AmazonBedrockProvider PROVIDER_WITH_TASK_TYPE = AmazonBedrockProvider.COHERE;
85-
8683
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
8784
InputType.INGEST,
8885
InputType.SEARCH,
8986
InputType.CLASSIFICATION,
9087
InputType.CLUSTERING,
9188
InputType.INTERNAL_INGEST,
92-
InputType.INTERNAL_SEARCH
89+
InputType.INTERNAL_SEARCH,
90+
InputType.UNSPECIFIED
9391
);
9492

9593
public AmazonBedrockService(
@@ -130,21 +128,8 @@ protected void doInfer(
130128

131129
@Override
132130
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
133-
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
134-
// inputType is only allowed when provider=cohere for text embeddings
135-
var provider = baseAmazonBedrockModel.provider();
136-
137-
if (Objects.equals(provider, PROVIDER_WITH_TASK_TYPE)) {
138-
// input type parameter allowed, so verify it is valid if specified
139-
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
140-
} else {
141-
// input type parameter not allowed so throw validation error if it is specified and not internal
142-
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(
143-
inputType,
144-
validationException,
145-
Strings.format("Invalid value [%s] received. [%s] is not allowed for provider [%s]", inputType, "input_type", provider)
146-
);
147-
}
131+
if (model instanceof AmazonBedrockModel) {
132+
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
148133
}
149134
}
150135

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -988,54 +988,6 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc
988988
verifyNoMoreInteractions(sender);
989989
}
990990

991-
public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForProviderThatDoesNotAcceptTaskType() throws IOException {
992-
var sender = mock(Sender.class);
993-
var factory = mock(HttpRequestSender.Factory.class);
994-
when(factory.createSender()).thenReturn(sender);
995-
996-
var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory(
997-
ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY),
998-
mockClusterServiceEmpty()
999-
);
1000-
var model = AmazonBedrockEmbeddingsModelTests.createModel(
1001-
"id",
1002-
"region",
1003-
"model",
1004-
AmazonBedrockProvider.AMAZONTITAN,
1005-
"access",
1006-
"secret"
1007-
);
1008-
1009-
try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) {
1010-
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
1011-
var thrownException = expectThrows(
1012-
ValidationException.class,
1013-
() -> service.infer(
1014-
model,
1015-
null,
1016-
null,
1017-
null,
1018-
List.of(""),
1019-
false,
1020-
new HashMap<>(),
1021-
InputType.INGEST,
1022-
InferenceAction.Request.DEFAULT_TIMEOUT,
1023-
listener
1024-
)
1025-
);
1026-
assertThat(
1027-
thrownException.getMessage(),
1028-
is("Validation Failed: 1: Invalid value [ingest] received. [input_type] is not allowed for provider [amazontitan];")
1029-
);
1030-
1031-
verify(factory, times(1)).createSender();
1032-
verify(sender, times(1)).start();
1033-
}
1034-
verify(sender, times(1)).close();
1035-
verifyNoMoreInteractions(factory);
1036-
verifyNoMoreInteractions(sender);
1037-
}
1038-
1039991
public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException {
1040992
var sender = mock(Sender.class);
1041993
var factory = mock(HttpRequestSender.Factory.class);

0 commit comments

Comments
 (0)