From 9e97819ab83cf48facef6fe35d660cb4676a20cf Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 12 Jun 2025 17:22:01 -0400 Subject: [PATCH 01/13] Making progress on different request parameters --- .../elasticsearch/inference/InputType.java | 64 +++++ .../inference/services/ServiceUtils.java | 12 + .../services/custom/CustomRequestManager.java | 13 +- .../custom/CustomServiceSettings.java | 45 ++- .../services/custom/InputTypeTranslator.java | 118 ++++++++ .../custom/request/CompletionParameters.java | 38 +++ .../custom/request/CustomRequest.java | 13 +- .../custom/request/RequestParameters.java | 38 +++ .../custom/request/RerankParameters.java | 54 ++++ .../custom/InputTypeTranslatorTests.java | 269 ++++++++++++++++++ 10 files changed, 649 insertions(+), 15 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java diff --git a/server/src/main/java/org/elasticsearch/inference/InputType.java b/server/src/main/java/org/elasticsearch/inference/InputType.java index 075d4af2968b3..46d148d04b6a4 100644 --- a/server/src/main/java/org/elasticsearch/inference/InputType.java +++ b/server/src/main/java/org/elasticsearch/inference/InputType.java @@ -10,8 +10,12 @@ package org.elasticsearch.inference; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import java.util.EnumSet; +import java.util.HashMap; import java.util.Locale; +import java.util.Map; import static org.elasticsearch.core.Strings.format; @@ -29,6 +33,13 @@ public enum InputType { INTERNAL_SEARCH, INTERNAL_INGEST; + private static final EnumSet SUPPORTED_REQUEST_VALUES = EnumSet.of( + InputType.CLASSIFICATION, + InputType.CLUSTERING, + InputType.INGEST, + InputType.SEARCH + ); + @Override public String toString() { return name().toLowerCase(Locale.ROOT); @@ -57,4 +68,57 @@ public static boolean isSpecified(InputType inputType) { public static String invalidInputTypeMessage(InputType inputType) { return Strings.format("received invalid input type value [%s]", inputType.toString()); } + + /** + * Ensures that a map used for translating input types is valid. The keys of the map are the external representation, + * and the values correspond to the values in this class. + * Throws a {@link ValidationException} if any value is not a valid InputType. + * + * @param inputTypeTranslation the map of input type translations to validate + * @param validationException a ValidationException to which errors will be added + */ + public static Map validateInputTypeTranslationValues( + Map inputTypeTranslation, + ValidationException validationException + ) { + if (inputTypeTranslation == null || inputTypeTranslation.isEmpty()) { + return Map.of(); + } + + var translationMap = new HashMap(); + + for (var entry : inputTypeTranslation.entrySet()) { + var key = entry.getKey(); + var value = entry.getValue(); + + if (value instanceof String == false || Strings.isNullOrEmpty((String) value)) { + validationException.addValidationError( + Strings.format( + "Input type translation value for key [%s] must be a String that is not null and not empty, received: [%s].", + key, + value.getClass().getSimpleName() + ) + ); + + throw validationException; + } + + try { + var inputTypeKey = InputType.fromRestString(key); + translationMap.put(inputTypeKey, (String) value); + } catch (Exception e) { + validationException.addValidationError( + Strings.format( + "Invalid input type translation for key: [%s], is not a valid value. Must be one of %s", + key, + EnumSet.of(InputType.CLASSIFICATION, InputType.CLUSTERING, InputType.INGEST, InputType.SEARCH) + ) + ); + + throw validationException; + } + } + + return translationMap; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index b12f5989e55ce..7d3ad374fdfa3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -400,6 +400,18 @@ public static String extractRequiredString( return requiredField; } + public static String extractOptionalEmptyString(Map map, String settingName, ValidationException validationException) { + int initialValidationErrorCount = validationException.validationErrors().size(); + String optionalField = ServiceUtils.removeAsType(map, settingName, String.class, validationException); + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + // new validation error occurred + return null; + } + + return optionalField; + } + public static String extractOptionalString( Map map, String settingName, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java index a112e7db26fe3..1095b29a64306 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java @@ -23,7 +23,10 @@ import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters; import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; +import org.elasticsearch.xpack.inference.services.custom.request.RequestParameters; +import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseEntity; import java.util.List; @@ -67,13 +70,11 @@ public void execute( ) { String query; List input; + RequestParameters requestParameters; if (inferenceInputs instanceof QueryAndDocsInputs) { - QueryAndDocsInputs queryAndDocsInputs = QueryAndDocsInputs.of(inferenceInputs); - query = queryAndDocsInputs.getQuery(); - input = queryAndDocsInputs.getChunks(); + requestParameters = RerankParameters.of(QueryAndDocsInputs.of(inferenceInputs)); } else if (inferenceInputs instanceof ChatCompletionInput chatInputs) { - query = null; - input = chatInputs.getInputs(); + requestParameters = CompletionParameters.of(chatInputs); } else if (inferenceInputs instanceof EmbeddingsInput) { EmbeddingsInput embeddingsInput = EmbeddingsInput.of(inferenceInputs); query = null; @@ -89,7 +90,7 @@ public void execute( } try { - var request = new CustomRequest(query, input, model); + var request = new CustomRequest(requestParameters, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, handler, hasRequestCompletedFunction, listener)); } catch (Exception e) { // Intentionally not logging this exception because it could contain sensitive information from the CustomRequest construction diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index bcad5554e5915..a7c02ba37456c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -125,6 +125,8 @@ public static CustomServiceSettings fromMap( context ); + var inputTypeTranslator = InputTypeTranslator.fromMap(map, validationException, CustomService.NAME); + if (requestBodyMap == null || responseParserMap == null || jsonParserMap == null || errorParserMap == null) { throw validationException; } @@ -146,7 +148,8 @@ public static CustomServiceSettings fromMap( requestContentString, responseJsonParser, rateLimitSettings, - errorParser + errorParser, + inputTypeTranslator ); } @@ -219,6 +222,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws private final CustomResponseParser responseJsonParser; private final RateLimitSettings rateLimitSettings; private final ErrorResponseParser errorParser; + private final InputTypeTranslator inputTypeTranslator; public CustomServiceSettings( TextEmbeddingSettings textEmbeddingSettings, @@ -229,6 +233,30 @@ public CustomServiceSettings( CustomResponseParser responseJsonParser, @Nullable RateLimitSettings rateLimitSettings, ErrorResponseParser errorParser + ) { + this( + textEmbeddingSettings, + url, + headers, + queryParameters, + requestContentString, + responseJsonParser, + rateLimitSettings, + errorParser, + InputTypeTranslator.EMPTY_TRANSLATOR + ); + } + + public CustomServiceSettings( + TextEmbeddingSettings textEmbeddingSettings, + String url, + @Nullable Map headers, + @Nullable QueryParameters queryParameters, + String requestContentString, + CustomResponseParser responseJsonParser, + @Nullable RateLimitSettings rateLimitSettings, + ErrorResponseParser errorParser, + InputTypeTranslator inputTypeTranslator ) { this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings); this.url = Objects.requireNonNull(url); @@ -238,6 +266,7 @@ public CustomServiceSettings( this.responseJsonParser = Objects.requireNonNull(responseJsonParser); this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); this.errorParser = Objects.requireNonNull(errorParser); + this.inputTypeTranslator = Objects.requireNonNull(inputTypeTranslator); } public CustomServiceSettings(StreamInput in) throws IOException { @@ -249,6 +278,7 @@ public CustomServiceSettings(StreamInput in) throws IOException { responseJsonParser = in.readNamedWriteable(CustomResponseParser.class); rateLimitSettings = new RateLimitSettings(in); errorParser = new ErrorResponseParser(in); + inputTypeTranslator = new InputTypeTranslator(in); } @Override @@ -296,6 +326,10 @@ public CustomResponseParser getResponseJsonParser() { return responseJsonParser; } + public InputTypeTranslator getInputTypeTranslator() { + return inputTypeTranslator; + } + public ErrorResponseParser getErrorParser() { return errorParser; } @@ -348,6 +382,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder } builder.endObject(); + inputTypeTranslator.toXContent(builder, params); + rateLimitSettings.toXContent(builder, params); return builder; @@ -373,6 +409,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(responseJsonParser); rateLimitSettings.writeTo(out); errorParser.writeTo(out); + inputTypeTranslator.writeTo(out); } @Override @@ -387,7 +424,8 @@ public boolean equals(Object o) { && Objects.equals(requestContentString, that.requestContentString) && Objects.equals(responseJsonParser, that.responseJsonParser) && Objects.equals(rateLimitSettings, that.rateLimitSettings) - && Objects.equals(errorParser, that.errorParser); + && Objects.equals(errorParser, that.errorParser) + && Objects.equals(inputTypeTranslator, that.inputTypeTranslator); } @Override @@ -400,7 +438,8 @@ public int hashCode() { requestContentString, responseJsonParser, rateLimitSettings, - errorParser + errorParser, + inputTypeTranslator ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java new file mode 100644 index 0000000000000..5a3400bde0fba --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEmptyString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class InputTypeTranslator implements ToXContentFragment, Writeable { + public static final String INPUT_TYPE_TRANSLATOR = "input_type"; + public static final String TRANSLATION = "translation"; + public static final String DEFAULT = "default"; + public static final InputTypeTranslator EMPTY_TRANSLATOR = new InputTypeTranslator(null, null); + + public static InputTypeTranslator fromMap(Map map, ValidationException validationException, String serviceName) { + if (map == null || map.isEmpty()) { + return EMPTY_TRANSLATOR; + } + + Map inputTypeTranslation = Objects.requireNonNullElse( + extractOptionalMap(map, INPUT_TYPE_TRANSLATOR, ModelConfigurations.SERVICE_SETTINGS, validationException), + new HashMap<>(Map.of()) + ); + + Map translationMap = extractOptionalMap( + inputTypeTranslation, + TRANSLATION, + INPUT_TYPE_TRANSLATOR, + validationException + ); + + var validatedTranslation = InputType.validateInputTypeTranslationValues(translationMap, validationException); + + var defaultValue = extractOptionalEmptyString(inputTypeTranslation, DEFAULT, validationException); + + throwIfNotEmptyMap(inputTypeTranslation, INPUT_TYPE_TRANSLATOR, "input_type_translator"); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new InputTypeTranslator(validatedTranslation, defaultValue); + } + + private final Map inputTypeTranslation; + private final String defaultValue; + + public InputTypeTranslator(@Nullable Map inputTypeTranslation, @Nullable String defaultValue) { + this.inputTypeTranslation = Objects.requireNonNullElse(inputTypeTranslation, Map.of()); + this.defaultValue = Objects.requireNonNullElse(defaultValue, ""); + } + + public InputTypeTranslator(StreamInput in) throws IOException { + this.inputTypeTranslation = in.readImmutableMap(keyReader -> keyReader.readEnum(InputType.class), StreamInput::readString); + this.defaultValue = in.readString(); + } + + public Map getInputTypeTranslation() { + return inputTypeTranslation; + } + + public String getDefaultValue() { + return defaultValue; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(INPUT_TYPE_TRANSLATOR); + { + builder.startObject(TRANSLATION); + for (var entry : inputTypeTranslation.entrySet()) { + builder.field(entry.getKey().toString(), entry.getValue()); + } + builder.endObject(); + builder.field(DEFAULT, defaultValue); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(inputTypeTranslation, StreamOutput::writeEnum, StreamOutput::writeString); + out.writeString(defaultValue); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + InputTypeTranslator that = (InputTypeTranslator) o; + return Objects.equals(inputTypeTranslation, that.inputTypeTranslation) && Objects.equals(defaultValue, that.defaultValue); + } + + @Override + public int hashCode() { + return Objects.hash(inputTypeTranslation, defaultValue); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java new file mode 100644 index 0000000000000..d12a95187b4db --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson; + +public class CompletionParameters extends RequestParameters { + + public static CompletionParameters of(ChatCompletionInput completionInput) { + return new CompletionParameters(Objects.requireNonNull(completionInput)); + } + + private CompletionParameters(ChatCompletionInput completionInput) { + super(completionInput.getInputs()); + } + + @Override + public Map jsonParameters() { + String jsonRep = toJson(inputs, INPUT); + + if (inputs.isEmpty() == false) { + jsonRep = toJson(inputs.get(0), INPUT); + } + + return Map.of(INPUT, jsonRep); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java index 0a50b08163260..c3f4d5bc64ffe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java @@ -44,7 +44,7 @@ public class CustomRequest implements Request { private final ValidatingSubstitutor stringPlaceholderReplacer; private final CustomModel model; - public CustomRequest(String query, List input, CustomModel model) { + public CustomRequest(RequestParameters requestParams, CustomModel model) { this.model = Objects.requireNonNull(model); var stringOnlyParams = new HashMap(); @@ -55,11 +55,13 @@ public CustomRequest(String query, List input, CustomModel model) { addJsonStringParams(jsonParams, model.getSecretSettings().getSecretParameters()); addJsonStringParams(jsonParams, model.getTaskSettings().getParameters()); - if (query != null) { - jsonParams.put(QUERY, toJson(query, QUERY)); - } + // if (query != null) { + // jsonParams.put(QUERY, toJson(query, QUERY)); + // } + + // addInputJsonParam(jsonParams, input, model.getTaskType()); - addInputJsonParam(jsonParams, input, model.getTaskType()); + jsonParams.putAll(requestParams.jsonParameters()); jsonPlaceholderReplacer = new ValidatingSubstitutor(jsonParams, "${", "}"); stringPlaceholderReplacer = new ValidatingSubstitutor(stringOnlyParams, "${", "}"); @@ -107,7 +109,6 @@ private URI buildUri() { } catch (URISyntaxException e) { throw new IllegalStateException(Strings.format("Failed to build URI, error: %s", e.getMessage()), e); } - } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java new file mode 100644 index 0000000000000..a3c877f61d61f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson; + +public abstract class RequestParameters { + + public static final String INPUT = "input"; + + protected final List inputs; + + public RequestParameters(List inputs) { + this.inputs = Objects.requireNonNull(inputs); + } + + Map jsonParameters() { + var additionalParameters = childParameters(); + var totalParameters = new HashMap<>(additionalParameters); + totalParameters.put(INPUT, toJson(inputs, INPUT)); + + return totalParameters; + } + + protected Map childParameters() { + return Map.of(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java new file mode 100644 index 0000000000000..b4ae8d04e3f88 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson; + +public class RerankParameters extends RequestParameters { + private static final String QUERY = "query"; + + public static RerankParameters of(QueryAndDocsInputs queryAndDocsInputs) { + Objects.requireNonNull(queryAndDocsInputs); + + return new RerankParameters(queryAndDocsInputs); + } + + private final QueryAndDocsInputs queryAndDocsInputs; + + private RerankParameters(QueryAndDocsInputs queryAndDocsInputs) { + super(queryAndDocsInputs.getChunks()); + this.queryAndDocsInputs = queryAndDocsInputs; + } + + @Override + protected Map childParameters() { + var additionalParameters = new HashMap(); + additionalParameters.put(QUERY, queryAndDocsInputs.getQuery()); + if (queryAndDocsInputs.getTopN() != null) { + additionalParameters.put( + InferenceAction.Request.TOP_N.getPreferredName(), + toJson(queryAndDocsInputs.getTopN(), InferenceAction.Request.TOP_N.getPreferredName()) + ); + } + + if (queryAndDocsInputs.getReturnDocuments() != null) { + additionalParameters.put( + InferenceAction.Request.RETURN_DOCUMENTS.getPreferredName(), + toJson(queryAndDocsInputs.getReturnDocuments(), InferenceAction.Request.RETURN_DOCUMENTS.getPreferredName()) + ); + } + return additionalParameters; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java new file mode 100644 index 0000000000000..f005f62fd6969 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java @@ -0,0 +1,269 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.core.Tuple.tuple; +import static org.hamcrest.Matchers.is; + +public class InputTypeTranslatorTests extends AbstractBWCWireSerializationTestCase { + public static InputTypeTranslator createRandom() { + Map translation = randomBoolean() + ? randomMap( + 0, + 5, + () -> tuple( + randomFrom(List.of(InputType.CLASSIFICATION, InputType.CLUSTERING, InputType.INGEST, InputType.SEARCH)), + randomAlphaOfLength(5) + ) + ) + : Map.of(); + return new InputTypeTranslator(translation, randomAlphaOfLength(5)); + } + + public void testFromMap() { + var settings = new HashMap( + Map.of( + InputTypeTranslator.INPUT_TYPE_TRANSLATOR, + new HashMap<>( + Map.of( + InputTypeTranslator.TRANSLATION, + new HashMap<>( + Map.of( + "CLASSIFICATION", + "test_value", + "CLUSTERING", + "test_value_2", + "INGEST", + "test_value_3", + "SEARCH", + "test_value_4" + ) + ), + InputTypeTranslator.DEFAULT, + "default_value" + ) + ) + ) + ); + + assertThat( + InputTypeTranslator.fromMap(settings, new ValidationException(), "name"), + is( + new InputTypeTranslator( + Map.of( + InputType.CLASSIFICATION, + "test_value", + InputType.CLUSTERING, + "test_value_2", + InputType.INGEST, + "test_value_3", + InputType.SEARCH, + "test_value_4" + ), + "default_value" + ) + ) + ); + } + + public void testFromMap_Null_EmptyMap_Returns_EmptySettings() { + assertThat(InputTypeTranslator.fromMap(null, null, null), is(InputTypeTranslator.EMPTY_TRANSLATOR)); + assertThat(InputTypeTranslator.fromMap(Map.of(), null, null), is(InputTypeTranslator.EMPTY_TRANSLATOR)); + } + + public void testFromMap_Throws_IfValueIsNotAString() { + var settings = new HashMap( + Map.of( + InputTypeTranslator.INPUT_TYPE_TRANSLATOR, + new HashMap<>(Map.of(InputTypeTranslator.TRANSLATION, new HashMap<>(Map.of("CLASSIFICATION", 12345)))) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> InputTypeTranslator.fromMap(settings, new ValidationException(), "name") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Input type translation value for key [CLASSIFICATION] " + + "must be a String that is not null and not empty, received: [Integer].;" + ) + ); + } + + public void testFromMap_Throws_IfValueIsEmptyString() { + var settings = new HashMap( + Map.of( + InputTypeTranslator.INPUT_TYPE_TRANSLATOR, + new HashMap<>(Map.of(InputTypeTranslator.TRANSLATION, new HashMap<>(Map.of("CLASSIFICATION", "")))) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> InputTypeTranslator.fromMap(settings, new ValidationException(), "name") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Input type translation value for key [CLASSIFICATION] " + + "must be a String that is not null and not empty, received: [String].;" + ) + ); + } + + public void testFromMap_DoesNotThrow_ForAnEmptyDefaultValue() { + var settings = new HashMap( + Map.of( + InputTypeTranslator.INPUT_TYPE_TRANSLATOR, + new HashMap<>( + Map.of( + InputTypeTranslator.TRANSLATION, + new HashMap<>(Map.of("CLASSIFICATION", "value")), + InputTypeTranslator.DEFAULT, + "" + ) + ) + ) + ); + + var translator = InputTypeTranslator.fromMap(settings, new ValidationException(), "name"); + + assertThat(translator, is(new InputTypeTranslator(Map.of(InputType.CLASSIFICATION, "value"), ""))); + } + + public void testFromMap_Throws_IfKeyIsInvalid() { + var settings = new HashMap( + Map.of( + InputTypeTranslator.INPUT_TYPE_TRANSLATOR, + new HashMap<>( + Map.of( + InputTypeTranslator.TRANSLATION, + new HashMap<>(Map.of("CLASSIFICATION", "test_value", "invalid_key", "another_value")) + ) + ) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> InputTypeTranslator.fromMap(settings, new ValidationException(), "name") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Invalid input type translation for key: [invalid_key]" + + ", is not a valid value. Must be one of [ingest, search, classification, clustering];" + ) + ); + } + + public void testFromMap_DefaultsToEmptyMap_WhenField_DoesNotExist() { + var map = new HashMap(Map.of("key", new HashMap<>(Map.of("test_key", "test_value")))); + + assertThat(InputTypeTranslator.fromMap(map, new ValidationException(), "name"), is(new InputTypeTranslator(Map.of(), null))); + } + + public void testXContent() throws IOException { + var entity = new InputTypeTranslator(Map.of(InputType.CLASSIFICATION, "test_value"), "default"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "input_type": { + "translation": { + "classification": "test_value" + }, + "default": "default" + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testXContent_EmptyTranslator() throws IOException { + var entity = new InputTypeTranslator(Map.of(), null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "input_type": { + "translation": {}, + "default": "" + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return InputTypeTranslator::new; + } + + @Override + protected InputTypeTranslator createTestInstance() { + return createRandom(); + } + + @Override + protected InputTypeTranslator mutateInstance(InputTypeTranslator instance) throws IOException { + return randomValueOtherThan(instance, InputTypeTranslatorTests::createRandom); + } + + public static Map getTaskSettingsMap(@Nullable Map parameters) { + var map = new HashMap(); + if (parameters != null) { + map.put(CustomTaskSettings.PARAMETERS, parameters); + } + + return map; + } + + @Override + protected InputTypeTranslator mutateInstanceForVersion(InputTypeTranslator instance, TransportVersion version) { + return instance; + } +} From 6988b7a29bb8c77fb3c10fe095fbdddf12f404ed Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 13 Jun 2025 15:41:05 -0400 Subject: [PATCH 02/13] Working tests --- .../services/custom/CustomRequestManager.java | 11 +- .../services/custom/InputTypeTranslator.java | 7 +- .../custom/request/CompletionParameters.java | 4 +- .../custom/request/EmbeddingParameters.java | 50 ++++++++ .../custom/request/RequestParameters.java | 4 +- .../custom/request/RerankParameters.java | 4 +- .../custom/CustomServiceSettingsTests.java | 55 +++++++++ .../request/CompletionParametersTests.java | 35 ++++++ .../custom/request/CustomRequestTests.java | 113 ++++++++++++++++-- .../request/EmbeddingParametersTests.java | 39 ++++++ .../custom/request/RerankParametersTests.java | 33 +++++ .../response/CustomResponseEntityTests.java | 48 ++++---- 12 files changed, 357 insertions(+), 46 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParametersTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParametersTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java index 1095b29a64306..e8f80565ce16c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java @@ -25,11 +25,11 @@ import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters; import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; +import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters; import org.elasticsearch.xpack.inference.services.custom.request.RequestParameters; import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseEntity; -import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -68,17 +68,16 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - String query; - List input; RequestParameters requestParameters; if (inferenceInputs instanceof QueryAndDocsInputs) { requestParameters = RerankParameters.of(QueryAndDocsInputs.of(inferenceInputs)); } else if (inferenceInputs instanceof ChatCompletionInput chatInputs) { requestParameters = CompletionParameters.of(chatInputs); } else if (inferenceInputs instanceof EmbeddingsInput) { - EmbeddingsInput embeddingsInput = EmbeddingsInput.of(inferenceInputs); - query = null; - input = embeddingsInput.getStringInputs(); + requestParameters = EmbeddingParameters.of( + EmbeddingsInput.of(inferenceInputs), + model.getServiceSettings().getInputTypeTranslator() + ); } else { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java index 5a3400bde0fba..c6a1e552eb279 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.TreeMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEmptyString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; @@ -75,7 +76,7 @@ public InputTypeTranslator(StreamInput in) throws IOException { this.defaultValue = in.readString(); } - public Map getInputTypeTranslation() { + public Map getTranslation() { return inputTypeTranslation; } @@ -85,10 +86,12 @@ public String getDefaultValue() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + var sortedMap = new TreeMap<>(inputTypeTranslation); + builder.startObject(INPUT_TYPE_TRANSLATOR); { builder.startObject(TRANSLATION); - for (var entry : inputTypeTranslation.entrySet()) { + for (var entry : sortedMap.entrySet()) { builder.field(entry.getKey().toString(), entry.getValue()); } builder.endObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java index d12a95187b4db..a1ec2a3515131 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java @@ -26,10 +26,12 @@ private CompletionParameters(ChatCompletionInput completionInput) { @Override public Map jsonParameters() { - String jsonRep = toJson(inputs, INPUT); + String jsonRep; if (inputs.isEmpty() == false) { jsonRep = toJson(inputs.get(0), INPUT); + } else { + jsonRep = toJson("", INPUT); } return Map.of(INPUT, jsonRep); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java new file mode 100644 index 0000000000000..71eb0ed8e098d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson; + +public class EmbeddingParameters extends RequestParameters { + private static final String INPUT_TYPE = "input_type"; + + public static EmbeddingParameters of(EmbeddingsInput embeddingsInput, InputTypeTranslator inputTypeTranslator) { + return new EmbeddingParameters(Objects.requireNonNull(embeddingsInput), Objects.requireNonNull(inputTypeTranslator)); + } + + private final InputType inputType; + private final InputTypeTranslator translator; + + private EmbeddingParameters(EmbeddingsInput embeddingsInput, InputTypeTranslator translator) { + super(embeddingsInput.getStringInputs()); + this.inputType = embeddingsInput.getInputType(); + this.translator = translator; + } + + @Override + protected Map taskTypeParameters() { + var additionalParameters = new HashMap(); + + if (inputType != null && translator.getTranslation().containsKey(inputType)) { + var inputTypeTranslation = translator.getTranslation().get(inputType); + + additionalParameters.put(INPUT_TYPE, toJson(inputTypeTranslation, INPUT_TYPE)); + } else { + additionalParameters.put(INPUT_TYPE, toJson(translator.getDefaultValue(), INPUT_TYPE)); + } + + return additionalParameters; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java index a3c877f61d61f..b03acc39c49d0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java @@ -25,14 +25,14 @@ public RequestParameters(List inputs) { } Map jsonParameters() { - var additionalParameters = childParameters(); + var additionalParameters = taskTypeParameters(); var totalParameters = new HashMap<>(additionalParameters); totalParameters.put(INPUT, toJson(inputs, INPUT)); return totalParameters; } - protected Map childParameters() { + protected Map taskTypeParameters() { return Map.of(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java index b4ae8d04e3f88..a6503c3ef65dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java @@ -33,9 +33,9 @@ private RerankParameters(QueryAndDocsInputs queryAndDocsInputs) { } @Override - protected Map childParameters() { + protected Map taskTypeParameters() { var additionalParameters = new HashMap(); - additionalParameters.put(QUERY, queryAndDocsInputs.getQuery()); + additionalParameters.put(QUERY, toJson(queryAndDocsInputs.getQuery(), QUERY)); if (queryAndDocsInputs.getTopN() != null) { additionalParameters.put( InferenceAction.Request.TOP_N.getPreferredName(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index 1bb3d44b897c8..25efee0cf4f02 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; @@ -698,6 +699,60 @@ public void testXContent() throws IOException { "path": "$.error.message" } }, + "input_type": { + "translation": {}, + "default": "" + }, + "rate_limit": { + "requests_per_minute": 10000 + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testXContent_WithInputTypeTranslationValues() throws IOException { + var entity = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.abc.com", + Map.of("key", "value"), + null, + "string", + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + null, + new ErrorResponseParser("$.error.message", "inference_id"), + new InputTypeTranslator(Map.of(InputType.SEARCH, "do_search", InputType.INGEST, "do_ingest"), "a_default") + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "url": "http://www.abc.com", + "headers": { + "key": "value" + }, + "request": { + "content": "string" + }, + "response": { + "json_parser": { + "text_embeddings": "$.result.embeddings[*].embedding" + }, + "error_parser": { + "path": "$.error.message" + } + }, + "input_type": { + "translation": { + "ingest": "do_ingest", + "search": "do_search" + }, + "default": "a_default" + }, "rate_limit": { "requests_per_minute": 10000 } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParametersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParametersTests.java new file mode 100644 index 0000000000000..eb2ee9c3d6d27 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParametersTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; + +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.custom.request.RequestParameters.INPUT; +import static org.hamcrest.Matchers.is; + +public class CompletionParametersTests extends ESTestCase { + + public void testJsonParameters_SingleValue() { + var parameters = CompletionParameters.of(new ChatCompletionInput(List.of("hello"))); + assertThat(parameters.jsonParameters(), is(Map.of(INPUT, "\"hello\""))); + } + + public void testJsonParameters_RetrievesFirstEntryFromList() { + var parameters = CompletionParameters.of(new ChatCompletionInput(List.of("hello", "hi"))); + assertThat(parameters.jsonParameters(), is(Map.of(INPUT, "\"hello\""))); + } + + public void testJsonParameters_EmptyList() { + var parameters = CompletionParameters.of(new ChatCompletionInput(List.of())); + assertThat(parameters.jsonParameters(), is(Map.of(INPUT, "\"\""))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index 06bfc0b1f6956..5cf740ef0d683 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -13,14 +13,18 @@ import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings; import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings; import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; +import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator; import org.elasticsearch.xpack.inference.services.custom.QueryParameters; import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; @@ -47,7 +51,8 @@ public void testCreateRequest() throws IOException { Map headers = Map.of(HttpHeaders.AUTHORIZATION, Strings.format("${api_key}")); var requestContentString = """ { - "input": ${input} + "input": ${input}, + "input_type": ${input_type} } """; @@ -64,7 +69,8 @@ public void testCreateRequest() throws IOException { requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new ErrorResponseParser("$.error.message", inferenceId), + new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") ); var model = CustomModelTests.createModel( @@ -75,7 +81,13 @@ public void testCreateRequest() throws IOException { new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) ); - var request = new CustomRequest(null, List.of("abc", "123"), model); + var request = new CustomRequest( + EmbeddingParameters.of( + new EmbeddingsInput(List.of("abc", "123"), null, null), + model.getServiceSettings().getInputTypeTranslator() + ), + model + ); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -86,18 +98,20 @@ public void testCreateRequest() throws IOException { var expectedBody = XContentHelper.stripWhitespace(""" { - "input": ["abc", "123"] + "input": ["abc", "123"], + "input_type": "default" } """); assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); } - public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { + public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOException { var inferenceId = "inferenceId"; var requestContentString = """ { - "input": ${input} + "input": ${input}, + "input_type": ${input_type} } """; @@ -118,7 +132,8 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new ErrorResponseParser("$.error.message", inferenceId), + new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") ); var model = CustomModelTests.createModel( @@ -129,7 +144,13 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) ); - var request = new CustomRequest(null, List.of("abc", "123"), model); + var request = new CustomRequest( + EmbeddingParameters.of( + new EmbeddingsInput(List.of("abc", "123"), null, InputType.INGEST), + model.getServiceSettings().getInputTypeTranslator() + ), + model + ); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -139,6 +160,14 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { // To visually verify that this is correct, input the query parameters into here: https://www.urldecoder.org/ is("http://www.elastic.co?key=+%3C%3E%23%25%2B%7B%7D%7C%5C%5E%7E%5B%5D%60%3B%2F%3F%3A%40%3D%26%24&key=%CE%A3+%F0%9F%98%80") ); + + var expectedBody = XContentHelper.stripWhitespace(""" + { + "input": ["abc", "123"], + "input_type": "value" + } + """); + assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); } public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws IOException { @@ -177,7 +206,13 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) ); - var request = new CustomRequest(null, List.of("abc", "123"), model); + var request = new CustomRequest( + EmbeddingParameters.of( + new EmbeddingsInput(List.of("abc", "123"), null, InputType.SEARCH), + model.getServiceSettings().getInputTypeTranslator() + ), + model + ); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -225,7 +260,7 @@ public void testCreateRequest_HandlesQuery() throws IOException { new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) ); - var request = new CustomRequest("query string", List.of("abc", "123"), model); + var request = new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -241,6 +276,57 @@ public void testCreateRequest_HandlesQuery() throws IOException { assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); } + public void testCreateRequest_HandlesQuery_WithReturnDocsAndTopN() throws IOException { + var inferenceId = "inference_id"; + var requestContentString = """ + { + "input": ${input}, + "query": ${query}, + "return_documents": ${return_documents}, + "top_n": ${top_n} + } + """; + + var serviceSettings = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.elastic.co", + null, + null, + requestContentString, + new RerankResponseParser("$.result.score"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.RERANK, + serviceSettings, + new CustomTaskSettings(Map.of()), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var request = new CustomRequest( + RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"), false, 2, false)), + model + ); + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var expectedBody = XContentHelper.stripWhitespace(""" + { + "input": ["abc", "123"], + "query": "query string", + "return_documents": false, + "top_n": 2 + } + """); + + assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); + } + public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IOException { var inferenceId = "inference_id"; var requestContentString = """ @@ -268,7 +354,7 @@ public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IO new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) ); - var request = new CustomRequest(null, List.of("abc", "123"), model); + var request = new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model); var exception = expectThrows(IllegalStateException.class, request::createHttpRequest); assertThat(exception.getMessage(), is("Found placeholder [${task.key}] in field [header.Accept] after replacement call")); } @@ -300,7 +386,10 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() { new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) ); - var exception = expectThrows(IllegalStateException.class, () -> new CustomRequest(null, List.of("abc", "123"), model)); + var exception = expectThrows( + IllegalStateException.class, + () -> new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model) + ); assertThat(exception.getMessage(), is("Failed to build URI, error: Illegal character in path at index 0: ^")); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java new file mode 100644 index 0000000000000..5231802ef4b92 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator; + +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class EmbeddingParametersTests extends ESTestCase { + + public void testTaskTypeParameters_UsesDefaultValue() { + var parameters = EmbeddingParameters.of( + new EmbeddingsInput(List.of("input"), null, InputType.INGEST), + new InputTypeTranslator(Map.of(), "default") + ); + + assertThat(parameters.taskTypeParameters(), is(Map.of("input_type", "\"default\""))); + } + + public void testTaskTypeParameters_UsesMappedValue() { + var parameters = EmbeddingParameters.of( + new EmbeddingsInput(List.of("input"), null, InputType.INGEST), + new InputTypeTranslator(Map.of(InputType.INGEST, "ingest_value"), "default") + ); + + assertThat(parameters.taskTypeParameters(), is(Map.of("input_type", "\"ingest_value\""))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParametersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParametersTests.java new file mode 100644 index 0000000000000..26a5b9dd5b501 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParametersTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; + +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class RerankParametersTests extends ESTestCase { + + public void testTaskTypeParameters() { + var queryAndDocsInputs = new QueryAndDocsInputs("query_value", List.of("doc1", "doc2"), true, 5, false); + var parameters = RerankParameters.of(queryAndDocsInputs); + + assertThat(parameters.taskTypeParameters(), is(Map.of("query", "\"query_value\"", "top_n", "5", "return_documents", "true"))); + } + + public void testTaskTypeParameters_WithoutOptionalFields() { + var queryAndDocsInputs = new QueryAndDocsInputs("query_value", List.of("doc1", "doc2")); + var parameters = RerankParameters.of(queryAndDocsInputs); + + assertThat(parameters.taskTypeParameters(), is(Map.of("query", "\"query_value\""))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java index e7f6a47e7c9c7..56aca095a661a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java @@ -17,8 +17,14 @@ import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; +import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters; import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; +import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters; +import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -53,10 +59,13 @@ public void testFromTextEmbeddingResponse() throws IOException { } """; + var model = CustomModelTests.getTestModel( + TaskType.TEXT_EMBEDDING, + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding") + ); var request = new CustomRequest( - null, - List.of("abc"), - CustomModelTests.getTestModel(TaskType.TEXT_EMBEDDING, new TextEmbeddingResponseParser("$.result.embeddings[*].embedding")) + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), + model ); InferenceServiceResults results = CustomResponseEntity.fromResponse( request, @@ -98,17 +107,18 @@ public void testFromSparseEmbeddingResponse() throws IOException { } """; - var request = new CustomRequest( - null, - List.of("abc"), - CustomModelTests.getTestModel( - TaskType.SPARSE_EMBEDDING, - new SparseEmbeddingResponseParser( - "$.result.sparse_embeddings[*].embedding[*].tokenId", - "$.result.sparse_embeddings[*].embedding[*].weight" - ) + var model = CustomModelTests.getTestModel( + TaskType.SPARSE_EMBEDDING, + new SparseEmbeddingResponseParser( + "$.result.sparse_embeddings[*].embedding[*].tokenId", + "$.result.sparse_embeddings[*].embedding[*].weight" ) ); + var request = new CustomRequest( + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), + model + + ); InferenceServiceResults results = CustomResponseEntity.fromResponse( request, @@ -152,14 +162,11 @@ public void testFromRerankResponse() throws IOException { } """; - var request = new CustomRequest( - null, - List.of("abc"), - CustomModelTests.getTestModel( - TaskType.RERANK, - new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null) - ) + var model = CustomModelTests.getTestModel( + TaskType.RERANK, + new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null) ); + var request = new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query", List.of("doc1", "doc2"))), model); InferenceServiceResults results = CustomResponseEntity.fromResponse( request, @@ -193,8 +200,7 @@ public void testFromCompletionResponse() throws IOException { """; var request = new CustomRequest( - null, - List.of("abc"), + CompletionParameters.of(new ChatCompletionInput(List.of("abc"))), CustomModelTests.getTestModel(TaskType.COMPLETION, new CompletionResponseParser("$.result.text")) ); From 4bc337ba2e032b30dff0f2bf7fea335eb33d942e Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 13 Jun 2025 16:59:49 -0400 Subject: [PATCH 03/13] Adding custom service validator for rerank --- server/src/main/java/module-info.java | 1 + .../inference/InferenceService.java | 11 +++ .../CustomServiceIntegrationValidator.java | 68 ++++++++++++++++++ .../ServiceIntegrationValidator.java | 2 +- .../TransportPutInferenceModelAction.java | 2 +- .../services/custom/CustomService.java | 11 +++ .../ChatCompletionModelValidator.java | 1 + ...icsearchInternalServiceModelValidator.java | 1 + .../validation/ModelValidatorBuilder.java | 32 ++++++--- ...CompletionServiceIntegrationValidator.java | 1 + .../validation/SimpleModelValidator.java | 1 + .../SimpleServiceIntegrationValidator.java | 1 + .../TextEmbeddingModelValidator.java | 1 + .../services/custom/CustomServiceTests.java | 2 +- .../ChatCompletionModelValidatorTests.java | 1 + .../ModelValidatorBuilderTests.java | 71 ++++++++++++++++++- .../validation/SimpleModelValidatorTests.java | 1 + .../TextEmbeddingModelValidatorTests.java | 1 + 18 files changed, 196 insertions(+), 13 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/inference/validation/CustomServiceIntegrationValidator.java rename {x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services => server/src/main/java/org/elasticsearch/inference}/validation/ServiceIntegrationValidator.java (91%) diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 8da4f403c29bd..c4d88967ebd93 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -473,6 +473,7 @@ org.elasticsearch.serverless.apifiltering; exports org.elasticsearch.lucene.spatial; exports org.elasticsearch.inference.configuration; + exports org.elasticsearch.inference.validation; exports org.elasticsearch.monitor.metrics; exports org.elasticsearch.plugins.internal.rewriter to org.elasticsearch.inference; exports org.elasticsearch.lucene.util.automaton; diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index e3e9abf7dc3f2..03d2806005788 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -14,6 +14,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import java.io.Closeable; import java.util.EnumSet; @@ -248,4 +249,14 @@ default void updateModelsWithDynamicFields(List model, ActionListener TEST_INPUT = List.of("how big"); + private static final String QUERY = "test query"; + + @Override + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { + service.infer( + model, + model.getTaskType().equals(TaskType.RERANK) ? QUERY : null, + true, + 1, + TEST_INPUT, + false, + Map.of(), + InputType.INTERNAL_INGEST, + timeout, + ActionListener.wrap(r -> { + if (r != null) { + listener.onResponse(r); + } else { + listener.onFailure( + new ElasticsearchStatusException( + "Could not complete custom service inference endpoint creation as" + + " validation call to service returned null response.", + RestStatus.BAD_REQUEST + ) + ); + } + }, + e -> listener.onFailure( + new ElasticsearchStatusException( + "Could not complete custom service inference endpoint creation as validation call to service threw an exception.", + RestStatus.BAD_REQUEST, + e + ) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java b/server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java similarity index 91% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java rename to server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java index 49ade6c00fb22..93fcdb3017829 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java +++ b/server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.validation; +package org.elasticsearch.inference.validation; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index bc9d87f43ada0..de8407c259c58 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -212,7 +212,7 @@ private void parseAndStoreModel( if (skipValidationAndStart) { storeModelListener.onResponse(model); } else { - ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService) + ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service) .validate(service, model, timeout, storeModelListener); } }); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 5e9aef099f622..cba23b47d4e8d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -26,6 +26,8 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.validation.CustomServiceIntegrationValidator; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -276,4 +278,13 @@ public static InferenceServiceConfiguration get() { } ); } + + @Override + public ServiceIntegrationValidator getServiceIntegrationValidator(TaskType taskType) { + if (taskType == TaskType.RERANK) { + return new CustomServiceIntegrationValidator(); + } + + return null; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java index 624f223c9f3e1..860e300a1c197 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java @@ -11,6 +11,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; public class ChatCompletionModelValidator implements ModelValidator { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java index 8fdb511ab31c6..fa0e1b3e590a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java @@ -15,6 +15,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java index d9eed10268527..fac9ee5e9c1c1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java @@ -8,34 +8,50 @@ package org.elasticsearch.xpack.inference.services.validation; import org.elasticsearch.core.Strings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; + +import java.util.Objects; public class ModelValidatorBuilder { - public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) { - if (isElasticsearchInternalService) { + public static ModelValidator buildModelValidator(TaskType taskType, InferenceService service) { + if (service instanceof ElasticsearchInternalService) { return new ElasticsearchInternalServiceModelValidator(new SimpleServiceIntegrationValidator()); } else { - return buildModelValidatorForTaskType(taskType); + return buildModelValidatorForTaskType(taskType, service); } } - private static ModelValidator buildModelValidatorForTaskType(TaskType taskType) { + private static ModelValidator buildModelValidatorForTaskType(TaskType taskType, InferenceService service) { if (taskType == null) { throw new IllegalArgumentException("Task type can't be null"); } + ServiceIntegrationValidator validatorFromService = null; + if (service != null) { + validatorFromService = service.getServiceIntegrationValidator(taskType); + } + switch (taskType) { case TEXT_EMBEDDING -> { - return new TextEmbeddingModelValidator(new SimpleServiceIntegrationValidator()); + return new TextEmbeddingModelValidator( + Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator()) + ); } case COMPLETION -> { - return new ChatCompletionModelValidator(new SimpleServiceIntegrationValidator()); + return new ChatCompletionModelValidator( + Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator()) + ); } case CHAT_COMPLETION -> { - return new ChatCompletionModelValidator(new SimpleChatCompletionServiceIntegrationValidator()); + return new ChatCompletionModelValidator( + Objects.requireNonNullElse(validatorFromService, new SimpleChatCompletionServiceIntegrationValidator()) + ); } case SPARSE_EMBEDDING, RERANK, ANY -> { - return new SimpleModelValidator(new SimpleServiceIntegrationValidator()); + return new SimpleModelValidator(Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator())); } default -> throw new IllegalArgumentException(Strings.format("Can't validate inference model for task type %s", taskType)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java index f9cf67172bc2a..61ec93e541924 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java @@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java index 3d592840f533b..41bdfc95dc5a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java @@ -11,6 +11,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; public class SimpleModelValidator implements ModelValidator { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java index 03ac5b95fddc5..d2cb1925ad7d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java index bff04f5af2d75..ce9df7376ebcb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java @@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index fb6c50f1bd9c4..8f4ca8ff55b58 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -135,7 +135,7 @@ private static void assertCompletionModel(Model model) { assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(CompletionResponseParser.class)); } - private static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); return new CustomService(senderFactory, createWithEmptySettings(threadPool)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java index bd52bdd52ab3e..03214b770bdc7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.test.ESTestCase; import org.junit.Before; import org.mockito.Mock; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java index 19ea0bedaaea5..e26ffa137c047 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java @@ -7,21 +7,88 @@ package org.elasticsearch.xpack.inference.services.validation; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceTests; +import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; +import org.junit.After; +import org.junit.Before; +import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.hamcrest.Matchers.isA; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; public class ModelValidatorBuilderTests extends ESTestCase { + + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + clientManager.close(); + terminate(threadPool); + } + + public void testCustomServiceValidator() { + var service = CustomServiceTests.createService(threadPool, clientManager); + var validator = ModelValidatorBuilder.buildModelValidator(TaskType.RERANK, service); + var mockService = mock(InferenceService.class); + validator.validate( + mockService, + CustomModelTests.getTestModel(TaskType.RERANK, new RerankResponseParser("score")), + null, + ActionListener.noop() + ); + + verify(mockService, times(1)).infer( + any(), + eq("test query"), + eq(true), + eq(1), + eq(List.of("how big")), + eq(false), + eq(Map.of()), + eq(InputType.INTERNAL_INGEST), + any(), + any() + ); + verifyNoMoreInteractions(mockService); + } + public void testBuildModelValidator_NullTaskType() { - assertThrows(IllegalArgumentException.class, () -> { ModelValidatorBuilder.buildModelValidator(null, false); }); + assertThrows(IllegalArgumentException.class, () -> { ModelValidatorBuilder.buildModelValidator(null, null); }); } public void testBuildModelValidator_ValidTaskType() { taskTypeToModelValidatorClassMap().forEach((taskType, modelValidatorClass) -> { - assertThat(ModelValidatorBuilder.buildModelValidator(taskType, false), isA(modelValidatorClass)); + assertThat(ModelValidatorBuilder.buildModelValidator(taskType, null), isA(modelValidatorClass)); }); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java index 418584d25d085..a793d5ec088b7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.junit.Before; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java index 55a02ebab082b..45726f0789667 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; From a3a532b8533dbf5c4e374ccc84023b5910b935f8 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 13 Jun 2025 22:20:09 -0400 Subject: [PATCH 04/13] Fixing embedding bug --- .../services/custom/CustomService.java | 15 ++++++++++++--- .../services/custom/CustomServiceSettings.java | 2 ++ .../services/custom/request/CustomRequest.java | 18 ------------------ 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index cba23b47d4e8d..49d7dc79a399d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -37,7 +37,6 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.util.EnumSet; import java.util.HashMap; @@ -201,7 +200,16 @@ public void doInfer( @Override protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { - ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + // We will support all values. If one of the values is not included in the translation map, we'll use the default provided instead + // if (model instanceof CustomModel customModel) { + // var translationKeys = customModel.getServiceSettings().getInputTypeTranslator().getTranslation().keySet(); + // var supportedInputTypes = translationKeys.isEmpty() ? EnumSet.noneOf(InputType.class) : EnumSet.copyOf(translationKeys); + // ServiceUtils.validateInputTypeAgainstAllowlist(inputType, supportedInputTypes, SERVICE_NAME, validationException); + // } else { + // validationException.addValidationError( + // Strings.format("Model of type [%s] is not supported by the %s service", model.getClass().getSimpleName(), NAME) + // ); + // } } @Override @@ -252,7 +260,8 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom serviceSettings.getRequestContentString(), serviceSettings.getResponseJsonParser(), serviceSettings.rateLimitSettings(), - serviceSettings.getErrorParser() + serviceSettings.getErrorParser(), + serviceSettings.getInputTypeTranslator() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index a7c02ba37456c..4db438c0e89e5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -278,6 +278,7 @@ public CustomServiceSettings(StreamInput in) throws IOException { responseJsonParser = in.readNamedWriteable(CustomResponseParser.class); rateLimitSettings = new RateLimitSettings(in); errorParser = new ErrorResponseParser(in); + // TODO need a new transport version inputTypeTranslator = new InputTypeTranslator(in); } @@ -409,6 +410,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(responseJsonParser); rateLimitSettings.writeTo(out); errorParser.writeTo(out); + // TODO need a new transport version inputTypeTranslator.writeTo(out); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java index c3f4d5bc64ffe..9ea0ac2a3182a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java @@ -14,7 +14,6 @@ import org.apache.http.entity.StringEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.common.ValidatingSubstitutor; import org.elasticsearch.xpack.inference.external.request.HttpRequest; @@ -27,7 +26,6 @@ import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Objects; @@ -36,8 +34,6 @@ import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.URL; public class CustomRequest implements Request { - private static final String QUERY = "query"; - private static final String INPUT = "input"; private final URI uri; private final ValidatingSubstitutor jsonPlaceholderReplacer; @@ -55,12 +51,6 @@ public CustomRequest(RequestParameters requestParams, CustomModel model) { addJsonStringParams(jsonParams, model.getSecretSettings().getSecretParameters()); addJsonStringParams(jsonParams, model.getTaskSettings().getParameters()); - // if (query != null) { - // jsonParams.put(QUERY, toJson(query, QUERY)); - // } - - // addInputJsonParam(jsonParams, input, model.getTaskType()); - jsonParams.putAll(requestParams.jsonParameters()); jsonPlaceholderReplacer = new ValidatingSubstitutor(jsonParams, "${", "}"); @@ -86,14 +76,6 @@ private static void addJsonStringParams(Map jsonStringParams, Ma } } - private static void addInputJsonParam(Map jsonParams, List input, TaskType taskType) { - if (taskType == TaskType.COMPLETION && input.isEmpty() == false) { - jsonParams.put(INPUT, toJson(input.get(0), INPUT)); - } else { - jsonParams.put(INPUT, toJson(input, INPUT)); - } - } - private URI buildUri() { var replacedUrl = stringPlaceholderReplacer.replace(model.getServiceSettings().getUrl(), URL); From 6176e3729e9aa02f737fd40247dd47e1a706b59a Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 16 Jun 2025 09:40:46 -0400 Subject: [PATCH 05/13] Adding transport version check --- .../org/elasticsearch/TransportVersions.java | 2 ++ .../custom/CustomServiceSettings.java | 8 ++++-- .../services/AbstractServiceTests.java | 28 ------------------- 3 files changed, 8 insertions(+), 30 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 57ba6de0b973c..4d47f109c0f9e 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -184,6 +184,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37); public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_38); public static final TransportVersion INFERENCE_CUSTOM_SERVICE_ADDED_8_19 = def(8_841_0_39); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19 = def(8_841_0_40); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -273,6 +274,7 @@ static TransportVersion def(int id) { public static final TransportVersion INFERENCE_CUSTOM_SERVICE_ADDED = def(9_084_0_00); public static final TransportVersion ESQL_LIMIT_ROW_SIZE = def(9_085_0_00); public static final TransportVersion ESQL_REGEX_MATCH_WITH_CASE_INSENSITIVITY = def(9_086_0_00); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE = def(9_087_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index 4db438c0e89e5..cad5f52572cd6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -278,8 +278,12 @@ public CustomServiceSettings(StreamInput in) throws IOException { responseJsonParser = in.readNamedWriteable(CustomResponseParser.class); rateLimitSettings = new RateLimitSettings(in); errorParser = new ErrorResponseParser(in); - // TODO need a new transport version - inputTypeTranslator = new InputTypeTranslator(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE) + || in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19)) { + inputTypeTranslator = new InputTypeTranslator(in); + } else { + inputTypeTranslator = InputTypeTranslator.EMPTY_TRANSLATOR; + } } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java index 24e0c7cadb73f..01b0786ecdd8d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; @@ -460,33 +459,6 @@ public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException { } } - public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { - var listener = new PlainActionFuture(); - - var exception = expectThrows( - ValidationException.class, - () -> service.infer( - getInvalidModel("id", "service"), - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) - ); - - assertThat( - exception.getMessage(), - is("Validation Failed: 1: Invalid input_type [ingest]. The input_type option is not supported by this service;") - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); From 02ffef665e32ecb69ba28975e45b5d77ec32c0b5 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 16 Jun 2025 10:35:26 -0400 Subject: [PATCH 06/13] Fixing tests --- .../inference/services/custom/CustomServiceSettingsTests.java | 4 +--- .../inference/services/custom/request/CustomRequestTests.java | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index ef757523ea3e7..35a487e72aae2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -687,9 +687,7 @@ public void testXContent_WithInputTypeTranslationValues() throws IOException { "headers": { "key": "value" }, - "request": { - "content": "string" - }, + "request": "string", "response": { "json_parser": { "text_embeddings": "$.result.embeddings[*].embedding" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index 9849bad71e05b..14e6caaff668d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -303,7 +303,7 @@ public void testCreateRequest_HandlesQuery_WithReturnDocsAndTopN() throws IOExce TaskType.RERANK, serviceSettings, new CustomTaskSettings(Map.of()), - new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray()))) ); var request = new CustomRequest( From 659f9e05930f4c4e629602dc7f687206ed69163c Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 16 Jun 2025 11:10:56 -0400 Subject: [PATCH 07/13] Fixing license header --- .../validation/CustomServiceIntegrationValidator.java | 9 +++++---- .../validation/ServiceIntegrationValidator.java | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/validation/CustomServiceIntegrationValidator.java b/server/src/main/java/org/elasticsearch/inference/validation/CustomServiceIntegrationValidator.java index 699f64c239a80..e2b4c4a6224f8 100644 --- a/server/src/main/java/org/elasticsearch/inference/validation/CustomServiceIntegrationValidator.java +++ b/server/src/main/java/org/elasticsearch/inference/validation/CustomServiceIntegrationValidator.java @@ -1,9 +1,10 @@ - /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". */ package org.elasticsearch.inference.validation; diff --git a/server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java b/server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java index 93fcdb3017829..83534f1320cff 100644 --- a/server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java +++ b/server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java @@ -1,8 +1,10 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". */ package org.elasticsearch.inference.validation; From b061fe2257a75b51709975cffefde61837512633 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 16 Jun 2025 13:09:19 -0400 Subject: [PATCH 08/13] Fixing writeTo --- .../inference/services/custom/CustomServiceSettings.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index 8d8e40d46b2b9..e3ef7c35e59f7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -400,8 +400,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(responseJsonParser); rateLimitSettings.writeTo(out); errorParser.writeTo(out); - // TODO need a new transport version - inputTypeTranslator.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE) + || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19)) { + inputTypeTranslator.writeTo(out); + } } @Override From a65a12b5d2203bad54ca392434f972f783154529 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 17 Jun 2025 16:54:42 -0400 Subject: [PATCH 09/13] Moving file and removing commented code --- .../inference/services/custom/CustomService.java | 14 +++----------- .../CustomServiceIntegrationValidator.java | 11 +++++------ 2 files changed, 8 insertions(+), 17 deletions(-) rename {server/src/main/java/org/elasticsearch/inference => x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services}/validation/CustomServiceIntegrationValidator.java (85%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 49d7dc79a399d..7c041757cd7a3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -26,7 +26,6 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.validation.CustomServiceIntegrationValidator; import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; @@ -37,6 +36,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.validation.CustomServiceIntegrationValidator; import java.util.EnumSet; import java.util.HashMap; @@ -200,16 +200,8 @@ public void doInfer( @Override protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { - // We will support all values. If one of the values is not included in the translation map, we'll use the default provided instead - // if (model instanceof CustomModel customModel) { - // var translationKeys = customModel.getServiceSettings().getInputTypeTranslator().getTranslation().keySet(); - // var supportedInputTypes = translationKeys.isEmpty() ? EnumSet.noneOf(InputType.class) : EnumSet.copyOf(translationKeys); - // ServiceUtils.validateInputTypeAgainstAllowlist(inputType, supportedInputTypes, SERVICE_NAME, validationException); - // } else { - // validationException.addValidationError( - // Strings.format("Model of type [%s] is not supported by the %s service", model.getClass().getSimpleName(), NAME) - // ); - // } + // The custom service doesn't do any validation for the input type because if the input type is supported a default + // must be supplied within the service settings. } @Override diff --git a/server/src/main/java/org/elasticsearch/inference/validation/CustomServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java similarity index 85% rename from server/src/main/java/org/elasticsearch/inference/validation/CustomServiceIntegrationValidator.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java index e2b4c4a6224f8..aee1ed3ec4ebc 100644 --- a/server/src/main/java/org/elasticsearch/inference/validation/CustomServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java @@ -1,13 +1,11 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.inference.validation; +package org.elasticsearch.xpack.inference.services.validation; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; @@ -17,6 +15,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import java.util.List; From 70be427e92e927b4efd219273871e817f9ccdae9 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 18 Jun 2025 11:54:58 -0400 Subject: [PATCH 10/13] Fixing test --- .../inference/services/custom/CustomServiceSettingsTests.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index 3b4a8b9cbc8f0..276a1cf856a15 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -600,7 +600,6 @@ public void testXContent_WithInputTypeTranslationValues() throws IOException { "string", new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), null, - new ErrorResponseParser("$.error.message", "inference_id"), new InputTypeTranslator(Map.of(InputType.SEARCH, "do_search", InputType.INGEST, "do_ingest"), "a_default") ); @@ -618,9 +617,6 @@ public void testXContent_WithInputTypeTranslationValues() throws IOException { "response": { "json_parser": { "text_embeddings": "$.result.embeddings[*].embedding" - }, - "error_parser": { - "path": "$.error.message" } }, "input_type": { From 5c39685e565151773ff4c00827c6406207cc248c Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 18 Jun 2025 15:58:42 -0400 Subject: [PATCH 11/13] Fixing tests --- .../custom/CustomServiceSettingsTests.java | 14 +++++++++++--- .../services/custom/CustomServiceTests.java | 3 ++- .../custom/request/CustomRequestTests.java | 2 ++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index da91be94cf758..d25db835fae04 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -158,7 +158,8 @@ public void testFromMap() { requestContentString, responseParser, new RateLimitSettings(10_000), - 11 + 11, + InputTypeTranslator.EMPTY_TRANSLATOR ) ) ); @@ -587,7 +588,8 @@ public void testXContent() throws IOException { }, "rate_limit": { "requests_per_minute": 10000 - } + }, + "batch_size": 10 } """); @@ -603,6 +605,7 @@ public void testXContent_WithInputTypeTranslationValues() throws IOException { "string", new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), null, + null, new InputTypeTranslator(Map.of(InputType.SEARCH, "do_search", InputType.INGEST, "do_ingest"), "a_default") ); @@ -648,7 +651,8 @@ public void testXContent_BatchSize11() throws IOException { "string", new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), null, - 11 + 11, + InputTypeTranslator.EMPTY_TRANSLATOR ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -667,6 +671,10 @@ public void testXContent_BatchSize11() throws IOException { "text_embeddings": "$.result.embeddings[*].embedding" } }, + "input_type": { + "translation": {}, + "default": "" + }, "rate_limit": { "requests_per_minute": 10000 }, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 66c147394a47c..69564489632cb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -276,7 +276,8 @@ private static CustomModel createInternalEmbeddingModel( "{\"input\":${input}}", parser, new RateLimitSettings(10_000), - batchSize + batchSize, + InputTypeTranslator.EMPTY_TRANSLATOR ), new CustomTaskSettings(Map.of("key", "test_value")), new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray()))), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index 492244677063b..5eadf2f9fe88f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -68,6 +68,7 @@ public void testCreateRequest() throws IOException { requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), new RateLimitSettings(10_000), + null, new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") ); @@ -130,6 +131,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOEx requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), new RateLimitSettings(10_000), + null, new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") ); From 9d5c2e5d9b020d567f0b5308a35edcc05fc1ded0 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 20 Jun 2025 09:29:33 -0400 Subject: [PATCH 12/13] Refactoring and tests --- .../elasticsearch/inference/InputType.java | 21 ++- .../xpack/inference/InputTypeTests.java | 131 ++++++++++++++++++ 2 files changed, 148 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/InputType.java b/server/src/main/java/org/elasticsearch/inference/InputType.java index 46d148d04b6a4..c930acdc0f45e 100644 --- a/server/src/main/java/org/elasticsearch/inference/InputType.java +++ b/server/src/main/java/org/elasticsearch/inference/InputType.java @@ -94,9 +94,11 @@ public static Map validateInputTypeTranslationValues( if (value instanceof String == false || Strings.isNullOrEmpty((String) value)) { validationException.addValidationError( Strings.format( - "Input type translation value for key [%s] must be a String that is not null and not empty, received: [%s].", + "Input type translation value for key [%s] must be a String that is " + + "not null and not empty, received: [%s], type: [%s].", key, - value.getClass().getSimpleName() + value, + value == null ? "null" : value.getClass().getSimpleName() ) ); @@ -104,14 +106,14 @@ public static Map validateInputTypeTranslationValues( } try { - var inputTypeKey = InputType.fromRestString(key); + var inputTypeKey = InputType.fromStringValidateSupportedRequestValue(key); translationMap.put(inputTypeKey, (String) value); } catch (Exception e) { validationException.addValidationError( Strings.format( "Invalid input type translation for key: [%s], is not a valid value. Must be one of %s", key, - EnumSet.of(InputType.CLASSIFICATION, InputType.CLUSTERING, InputType.INGEST, InputType.SEARCH) + SUPPORTED_REQUEST_VALUES ) ); @@ -121,4 +123,15 @@ public static Map validateInputTypeTranslationValues( return translationMap; } + + private static InputType fromStringValidateSupportedRequestValue(String name) { + var inputType = fromRestString(name); + if (SUPPORTED_REQUEST_VALUES.contains(inputType) == false) { + throw new IllegalArgumentException( + format("Unrecognized input_type [%s], must be one of %s", inputType, SUPPORTED_REQUEST_VALUES) + ); + } + + return inputType; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java index cd51de0b57125..c6e08a8d5bdf3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java @@ -7,10 +7,13 @@ package org.elasticsearch.xpack.inference; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.elasticsearch.core.Strings.format; import static org.hamcrest.CoreMatchers.is; @@ -80,4 +83,132 @@ public void testFromRestString_ThrowsErrorForInvalidInputTypes() { assertThat(thrownException.getMessage(), is("No enum constant org.elasticsearch.inference.InputType.FOO")); } + + public void testValidateInputTypeTranslationValues() { + assertThat( + InputType.validateInputTypeTranslationValues( + Map.of( + InputType.INGEST.toString(), + "ingest_value", + InputType.SEARCH.toString(), + "search_value", + InputType.CLASSIFICATION.toString(), + "classification_value", + InputType.CLUSTERING.toString(), + "clustering_value" + ), + new ValidationException() + ), + is( + Map.of( + InputType.INGEST, + "ingest_value", + InputType.SEARCH, + "search_value", + InputType.CLASSIFICATION, + "classification_value", + InputType.CLUSTERING, + "clustering_value" + ) + ) + ); + } + + public void testValidateInputTypeTranslationValues_ReturnsEmptyMap_WhenTranslationIsNull() { + assertThat(InputType.validateInputTypeTranslationValues(null, new ValidationException()), is(Map.of())); + } + + public void testValidateInputTypeTranslationValues_ReturnsEmptyMap_WhenTranslationIsAnEmptyMap() { + assertThat(InputType.validateInputTypeTranslationValues(Map.of(), new ValidationException()), is(Map.of())); + } + + public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenInputTypeIsUnspecified() { + var exception = expectThrows( + ValidationException.class, + () -> InputType.validateInputTypeTranslationValues( + Map.of(InputType.INGEST.toString(), "ingest_value", InputType.UNSPECIFIED.toString(), "unspecified_value"), + new ValidationException() + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Invalid input type translation for key: [unspecified], is not a valid value. Must be " + + "one of [ingest, search, classification, clustering];" + ) + ); + } + + public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenInputTypeIsInternal() { + var exception = expectThrows( + ValidationException.class, + () -> InputType.validateInputTypeTranslationValues( + Map.of(InputType.INGEST.toString(), "ingest_value", InputType.INTERNAL_INGEST.toString(), "internal_ingest_value"), + new ValidationException() + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Invalid input type translation for key: [internal_ingest], is not a valid value. Must be " + + "one of [ingest, search, classification, clustering];" + ) + ); + } + + public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenValueIsNull() { + var translation = new HashMap(); + translation.put(InputType.INGEST.toString(), null); + + var exception = expectThrows( + ValidationException.class, + () -> InputType.validateInputTypeTranslationValues(translation, new ValidationException()) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Input type translation value for key [ingest] must be a String that " + + "is not null and not empty, received: [null], type: [null].;" + ) + ); + } + + public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenValueIsAnEmptyString() { + var translation = new HashMap(); + translation.put(InputType.INGEST.toString(), ""); + + var exception = expectThrows( + ValidationException.class, + () -> InputType.validateInputTypeTranslationValues(translation, new ValidationException()) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Input type translation value for key [ingest] must be a String that " + + "is not null and not empty, received: [], type: [String].;" + ) + ); + } + + public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenValueIsNotAString() { + var translation = new HashMap(); + translation.put(InputType.INGEST.toString(), 1); + + var exception = expectThrows( + ValidationException.class, + () -> InputType.validateInputTypeTranslationValues(translation, new ValidationException()) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Input type translation value for key [ingest] must be a String that " + + "is not null and not empty, received: [1], type: [Integer].;" + ) + ); + } } From 35035ad1c5399e1f4a0e653bc99130ccbd74d471 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 20 Jun 2025 10:28:48 -0400 Subject: [PATCH 13/13] Fixing test --- .../inference/services/custom/InputTypeTranslatorTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java index f005f62fd6969..0d53c11967d4c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java @@ -110,7 +110,7 @@ public void testFromMap_Throws_IfValueIsNotAString() { exception.getMessage(), is( "Validation Failed: 1: Input type translation value for key [CLASSIFICATION] " - + "must be a String that is not null and not empty, received: [Integer].;" + + "must be a String that is not null and not empty, received: [12345], type: [Integer].;" ) ); } @@ -132,7 +132,7 @@ public void testFromMap_Throws_IfValueIsEmptyString() { exception.getMessage(), is( "Validation Failed: 1: Input type translation value for key [CLASSIFICATION] " - + "must be a String that is not null and not empty, received: [String].;" + + "must be a String that is not null and not empty, received: [], type: [String].;" ) ); }