Skip to content

Commit 3a3e946

Browse files
committed
service tests
1 parent ea4ad64 commit 3a3e946

File tree

35 files changed

+445
-543
lines changed

35 files changed

+445
-543
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ public void chunkedInfer(
120120
ActionListener<List<ChunkedInference>> listener
121121
) {
122122
init();
123+
124+
ValidationException validationException = new ValidationException();
125+
validateInputType(inputType, model, validationException);
126+
if (validationException.validationErrors().isEmpty() == false) {
127+
throw validationException;
128+
}
129+
123130
// a non-null query is not supported and is dropped by all providers
124131
doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener);
125132
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,5 +813,28 @@ public static void validateInputTypeIsUnspecifiedOrInternal(InputType inputType,
813813
}
814814
}
815815

816+
public static void validateInputTypeIsUnspecifiedOrInternal(
817+
InputType inputType,
818+
ValidationException validationException,
819+
String customErrorMessage
820+
) {
821+
if (inputType != null && inputType != InputType.UNSPECIFIED && VALID_INTERNAL_INPUT_TYPE_VALUES.contains(inputType) == false) {
822+
validationException.addValidationError(customErrorMessage);
823+
}
824+
}
825+
826+
public static void validateInputTypeAgainstAllowlist(
827+
InputType inputType,
828+
EnumSet<InputType> allowedInputTypes,
829+
String name,
830+
ValidationException validationException
831+
) {
832+
if (inputType != null && inputType != InputType.UNSPECIFIED && allowedInputTypes.contains(inputType) == false) {
833+
validationException.addValidationError(
834+
org.elasticsearch.common.Strings.format("Input type [%s] is not supported for [%s]", inputType, name)
835+
);
836+
}
837+
}
838+
816839
private ServiceUtils() {}
817840
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14-
import org.elasticsearch.common.Strings;
1514
import org.elasticsearch.common.ValidationException;
1615
import org.elasticsearch.common.util.LazyInitializable;
1716
import org.elasticsearch.core.Nullable;
@@ -298,9 +297,7 @@ public void doInfer(
298297

299298
@Override
300299
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
301-
if (VALID_INPUT_TYPE_VALUES.contains(inputType) == false) {
302-
validationException.addValidationError(Strings.format("Input type [%s] is not supported for [%s]", inputType, SERVICE_NAME));
303-
}
300+
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
304301
}
305302

306303
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
import java.util.Set;
5858

5959
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
60-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.VALID_INTERNAL_INPUT_TYPE_VALUES;
6160
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
6261
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
6362
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
@@ -132,30 +131,19 @@ protected void doInfer(
132131
@Override
133132
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
134133
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
135-
// InputType is only respected when provider=cohere for text embeddings
134+
// inputType is only allowed when provider=cohere for text embeddings
136135
var provider = baseAmazonBedrockModel.provider();
137136

138-
if (Objects.equals(provider, PROVIDER_WITH_TASK_TYPE) == false) {
139-
// this model does not accept input type parameter so throw validation error if it is specified and not internal
140-
if (inputType != null
141-
&& inputType != InputType.UNSPECIFIED
142-
&& VALID_INTERNAL_INPUT_TYPE_VALUES.contains(inputType) == false) {
143-
validationException.addValidationError(
144-
Strings.format(
145-
"Invalid value [%s] received. [%s] is not allowed for provider [%s]",
146-
inputType,
147-
"input_type",
148-
provider
149-
)
150-
);
151-
}
137+
if (Objects.equals(provider, PROVIDER_WITH_TASK_TYPE)) {
138+
// input type parameter allowed, so verify it is valid if specified
139+
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
152140
} else {
153-
// this model does accept input type parameter, so verify it is valid if specified
154-
if (inputType != null && inputType != InputType.UNSPECIFIED && VALID_INPUT_TYPE_VALUES.contains(inputType) == false) {
155-
validationException.addValidationError(
156-
Strings.format("Input type [%s] is not supported for [%s]", inputType, SERVICE_NAME)
157-
);
158-
}
141+
// input type parameter not allowed so throw validation error if it is specified and not internal
142+
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(
143+
inputType,
144+
validationException,
145+
Strings.format("Invalid value [%s] received. [%s] is not allowed for provider [%s]", inputType, "input_type", provider)
146+
);
159147
}
160148
}
161149
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ public class AzureAiStudioService extends SenderService {
7777
private static final String SERVICE_NAME = "Azure AI Studio";
7878
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
7979

80+
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
81+
InputType.INGEST,
82+
InputType.SEARCH,
83+
InputType.INTERNAL_INGEST,
84+
InputType.INTERNAL_SEARCH
85+
);
86+
8087
public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
8188
super(factory, serviceComponents);
8289
}
@@ -110,7 +117,9 @@ protected void doInfer(
110117
}
111118

112119
@Override
113-
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {}
120+
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
121+
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
122+
}
114123

115124
@Override
116125
protected void doChunkedInfer(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14-
import org.elasticsearch.common.Strings;
1514
import org.elasticsearch.common.ValidationException;
1615
import org.elasticsearch.common.util.LazyInitializable;
1716
import org.elasticsearch.core.Nullable;
@@ -279,9 +278,7 @@ public void doInfer(
279278

280279
@Override
281280
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
282-
if (VALID_INPUT_TYPE_VALUES.contains(inputType) == false) {
283-
validationException.addValidationError(Strings.format("Input type [%s] is not supported for [%s]", inputType, SERVICE_NAME));
284-
}
281+
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
285282
}
286283

287284
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959

6060
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
6161
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
62-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.VALID_INTERNAL_INPUT_TYPE_VALUES;
6362
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
6463
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
6564
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
@@ -320,25 +319,19 @@ protected void doInfer(
320319
@Override
321320
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
322321
if (model instanceof GoogleAiStudioEmbeddingsModel embeddingsModel) {
323-
// InputType is only respected when model=embedding-001 https://ai.google.dev/api/embeddings?authuser=5#EmbedContentRequest
322+
// inputType is only allowed when model=embedding-001 https://ai.google.dev/api/embeddings?authuser=5#EmbedContentRequest
324323
var modelId = embeddingsModel.getServiceSettings().modelId();
325324

326-
if (Objects.equals(modelId, MODEL_ID_WITH_TASK_TYPE) == false) {
327-
// this model does not accept input type parameter so throw validation error if it is specified and not internal
328-
if (inputType != null
329-
&& inputType != InputType.UNSPECIFIED
330-
&& VALID_INTERNAL_INPUT_TYPE_VALUES.contains(inputType) == false) {
331-
// throw validation exception if ingest type is specified
332-
validationException.addValidationError(
333-
Strings.format("Invalid value [%s] received. [%s] is not allowed for model [%s]", inputType, "input_type", modelId)
334-
);
335-
}
325+
if (Objects.equals(modelId, MODEL_ID_WITH_TASK_TYPE)) {
326+
// input type parameter allowed, so verify it is valid if specified
327+
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
336328
} else {
337-
if (inputType != null && inputType != InputType.UNSPECIFIED && VALID_INPUT_TYPE_VALUES.contains(inputType) == false) {
338-
validationException.addValidationError(
339-
Strings.format("Input type [%s] is not supported for [%s]", inputType, SERVICE_NAME)
340-
);
341-
}
329+
// input type parameter not allowed so throw validation error if it is specified and not internal
330+
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(
331+
inputType,
332+
validationException,
333+
Strings.format("Invalid value [%s] received. [%s] is not allowed for model [%s]", inputType, "input_type", modelId)
334+
);
342335
}
343336
}
344337
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import static org.elasticsearch.core.Strings.format;
2828

2929
public class GoogleAiStudioEmbeddingsModel extends GoogleAiStudioModel {
30+
3031
private URI uri;
3132

3233
public GoogleAiStudioEmbeddingsModel(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14-
import org.elasticsearch.common.Strings;
1514
import org.elasticsearch.common.ValidationException;
1615
import org.elasticsearch.common.util.LazyInitializable;
1716
import org.elasticsearch.core.Nullable;
@@ -218,9 +217,7 @@ protected void doInfer(
218217

219218
@Override
220219
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
221-
if (VALID_INPUT_TYPE_VALUES.contains(inputType) == false) {
222-
validationException.addValidationError(Strings.format("Input type [%s] is not supported for [%s]", inputType, SERVICE_NAME));
223-
}
220+
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
224221
}
225222

226223
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.Map;
2323

2424
public class HuggingFaceEmbeddingsModel extends HuggingFaceModel {
25-
2625
public HuggingFaceEmbeddingsModel(
2726
String inferenceEntityId,
2827
TaskType taskType,

0 commit comments

Comments
 (0)