diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 7b2b6d91fdc83..a38fc29950b4b 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -475,6 +475,7 @@ org.elasticsearch.serverless.apifiltering; exports org.elasticsearch.lucene.spatial; exports org.elasticsearch.inference.configuration; + exports org.elasticsearch.inference.validation; exports org.elasticsearch.monitor.metrics; exports org.elasticsearch.plugins.internal.rewriter to org.elasticsearch.inference; exports org.elasticsearch.lucene.util.automaton; diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 96657a70a20e5..b949cb62f6a06 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -200,6 +200,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52); public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_53); public static final TransportVersion STREAMS_LOGS_SUPPORT_8_19 = def(8_841_0_54); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19 = def(8_841_0_55); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); @@ -308,6 +309,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00); public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00); public static final TransportVersion STREAMS_LOGS_SUPPORT = def(9_104_0_00); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE = def(9_105_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index e3e9abf7dc3f2..03d2806005788 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -14,6 +14,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import java.io.Closeable; import java.util.EnumSet; @@ -248,4 +249,14 @@ default void updateModelsWithDynamicFields(List model, ActionListener SUPPORTED_REQUEST_VALUES = EnumSet.of( + InputType.CLASSIFICATION, + InputType.CLUSTERING, + InputType.INGEST, + InputType.SEARCH + ); + @Override public String toString() { return name().toLowerCase(Locale.ROOT); @@ -57,4 +68,70 @@ public static boolean isSpecified(InputType inputType) { public static String invalidInputTypeMessage(InputType inputType) { return Strings.format("received invalid input type value [%s]", inputType.toString()); } + + /** + * Ensures that a map used for translating input types is valid. The keys of the map are the external representation, + * and the values correspond to the values in this class. + * Throws a {@link ValidationException} if any value is not a valid InputType. + * + * @param inputTypeTranslation the map of input type translations to validate + * @param validationException a ValidationException to which errors will be added + */ + public static Map validateInputTypeTranslationValues( + Map inputTypeTranslation, + ValidationException validationException + ) { + if (inputTypeTranslation == null || inputTypeTranslation.isEmpty()) { + return Map.of(); + } + + var translationMap = new HashMap(); + + for (var entry : inputTypeTranslation.entrySet()) { + var key = entry.getKey(); + var value = entry.getValue(); + + if (value instanceof String == false || Strings.isNullOrEmpty((String) value)) { + validationException.addValidationError( + Strings.format( + "Input type translation value for key [%s] must be a String that is " + + "not null and not empty, received: [%s], type: [%s].", + key, + value, + value == null ? "null" : value.getClass().getSimpleName() + ) + ); + + throw validationException; + } + + try { + var inputTypeKey = InputType.fromStringValidateSupportedRequestValue(key); + translationMap.put(inputTypeKey, (String) value); + } catch (Exception e) { + validationException.addValidationError( + Strings.format( + "Invalid input type translation for key: [%s], is not a valid value. Must be one of %s", + key, + SUPPORTED_REQUEST_VALUES + ) + ); + + throw validationException; + } + } + + return translationMap; + } + + private static InputType fromStringValidateSupportedRequestValue(String name) { + var inputType = fromRestString(name); + if (SUPPORTED_REQUEST_VALUES.contains(inputType) == false) { + throw new IllegalArgumentException( + format("Unrecognized input_type [%s], must be one of %s", inputType, SUPPORTED_REQUEST_VALUES) + ); + } + + return inputType; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java b/server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java similarity index 54% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java rename to server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java index 49ade6c00fb22..83534f1320cff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java +++ b/server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java @@ -1,11 +1,13 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.xpack.inference.services.validation; +package org.elasticsearch.inference.validation; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index eb401354f8c14..80d57f888ef6e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -216,7 +216,7 @@ private void parseAndStoreModel( if (skipValidationAndStart) { storeModelListener.onResponse(model); } else { - ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService) + ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service) .validate(service, model, timeout, storeModelListener); } }); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index f8a77c4da2e85..a3e9d69654f1f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -399,6 +399,18 @@ public static String extractRequiredString( return requiredField; } + public static String extractOptionalEmptyString(Map map, String settingName, ValidationException validationException) { + int initialValidationErrorCount = validationException.validationErrors().size(); + String optionalField = ServiceUtils.removeAsType(map, settingName, String.class, validationException); + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + // new validation error occurred + return null; + } + + return optionalField; + } + public static String extractOptionalString( Map map, String settingName, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java index fa1e753ec5eff..6e2400c33d6a2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java @@ -23,10 +23,13 @@ import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters; import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; +import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters; +import org.elasticsearch.xpack.inference.services.custom.request.RequestParameters; +import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseEntity; -import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -65,19 +68,16 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - String query; - List input; + RequestParameters requestParameters; if (inferenceInputs instanceof QueryAndDocsInputs) { - QueryAndDocsInputs queryAndDocsInputs = QueryAndDocsInputs.of(inferenceInputs); - query = queryAndDocsInputs.getQuery(); - input = queryAndDocsInputs.getChunks(); + requestParameters = RerankParameters.of(QueryAndDocsInputs.of(inferenceInputs)); } else if (inferenceInputs instanceof ChatCompletionInput chatInputs) { - query = null; - input = chatInputs.getInputs(); + requestParameters = CompletionParameters.of(chatInputs); } else if (inferenceInputs instanceof EmbeddingsInput) { - EmbeddingsInput embeddingsInput = EmbeddingsInput.of(inferenceInputs); - query = null; - input = embeddingsInput.getStringInputs(); + requestParameters = EmbeddingParameters.of( + EmbeddingsInput.of(inferenceInputs), + model.getServiceSettings().getInputTypeTranslator() + ); } else { listener.onFailure( new ElasticsearchStatusException( @@ -89,7 +89,7 @@ public void execute( } try { - var request = new CustomRequest(query, input, model); + var request = new CustomRequest(requestParameters, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, handler, hasRequestCompletedFunction, listener)); } catch (Exception e) { // Intentionally not logging this exception because it could contain sensitive information from the CustomRequest construction diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index bab7192a5a7a8..deb6e17ec5311 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -27,19 +27,26 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters; import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; +import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters; +import org.elasticsearch.xpack.inference.services.custom.request.RequestParameters; +import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters; +import org.elasticsearch.xpack.inference.services.validation.CustomServiceIntegrationValidator; import java.util.EnumSet; import java.util.HashMap; @@ -115,13 +122,8 @@ public void parseRequestConfig( * This does some initial validation with mock inputs to determine if any templates are missing a field to fill them. */ private static void validateConfiguration(CustomModel model) { - String query = null; - if (model.getTaskType() == TaskType.RERANK) { - query = "test query"; - } - try { - new CustomRequest(query, List.of("test input"), model).createHttpRequest(); + new CustomRequest(createParameters(model), model).createHttpRequest(); } catch (IllegalStateException e) { var validationException = new ValidationException(); validationException.addValidationError(Strings.format("Failed to validate model configuration: %s", e.getMessage())); @@ -129,6 +131,20 @@ private static void validateConfiguration(CustomModel model) { } } + private static RequestParameters createParameters(CustomModel model) { + return switch (model.getTaskType()) { + case RERANK -> RerankParameters.of(new QueryAndDocsInputs("test query", List.of("test input"))); + case COMPLETION -> CompletionParameters.of(new ChatCompletionInput(List.of("test input"))); + case TEXT_EMBEDDING, SPARSE_EMBEDDING -> EmbeddingParameters.of( + new EmbeddingsInput(List.of("test input"), null, null), + model.getServiceSettings().getInputTypeTranslator() + ); + default -> throw new IllegalStateException( + Strings.format("Unsupported task type [%s] for custom service", model.getTaskType()) + ); + }; + } + private static ChunkingSettings extractChunkingSettings(Map config, TaskType taskType) { if (TaskType.TEXT_EMBEDDING.equals(taskType)) { return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); @@ -257,7 +273,8 @@ public void doInfer( @Override protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { - ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + // The custom service doesn't do any validation for the input type because if the input type is supported a default + // must be supplied within the service settings. } @Override @@ -327,7 +344,9 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom serviceSettings.getQueryParameters(), serviceSettings.getRequestContentString(), serviceSettings.getResponseJsonParser(), - serviceSettings.rateLimitSettings() + serviceSettings.rateLimitSettings(), + serviceSettings.getBatchSize(), + serviceSettings.getInputTypeTranslator() ); } @@ -353,4 +372,13 @@ public static InferenceServiceConfiguration get() { } ); } + + @Override + public ServiceIntegrationValidator getServiceIntegrationValidator(TaskType taskType) { + if (taskType == TaskType.RERANK) { + return new CustomServiceIntegrationValidator(); + } + + return null; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index 15caace9820fc..83048120bc545 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -110,6 +110,7 @@ public static CustomServiceSettings fromMap( context ); + var inputTypeTranslator = InputTypeTranslator.fromMap(map, validationException, CustomService.NAME); var batchSize = extractOptionalPositiveInteger(map, BATCH_SIZE, ModelConfigurations.SERVICE_SETTINGS, validationException); if (responseParserMap == null || jsonParserMap == null) { @@ -131,7 +132,8 @@ public static CustomServiceSettings fromMap( requestContentString, responseJsonParser, rateLimitSettings, - batchSize + batchSize, + inputTypeTranslator ); } @@ -203,6 +205,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws private final CustomResponseParser responseJsonParser; private final RateLimitSettings rateLimitSettings; private final int batchSize; + private final InputTypeTranslator inputTypeTranslator; public CustomServiceSettings( TextEmbeddingSettings textEmbeddingSettings, @@ -213,7 +216,17 @@ public CustomServiceSettings( CustomResponseParser responseJsonParser, @Nullable RateLimitSettings rateLimitSettings ) { - this(textEmbeddingSettings, url, headers, queryParameters, requestContentString, responseJsonParser, rateLimitSettings, null); + this( + textEmbeddingSettings, + url, + headers, + queryParameters, + requestContentString, + responseJsonParser, + rateLimitSettings, + null, + InputTypeTranslator.EMPTY_TRANSLATOR + ); } public CustomServiceSettings( @@ -224,7 +237,8 @@ public CustomServiceSettings( String requestContentString, CustomResponseParser responseJsonParser, @Nullable RateLimitSettings rateLimitSettings, - @Nullable Integer batchSize + @Nullable Integer batchSize, + InputTypeTranslator inputTypeTranslator ) { this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings); this.url = Objects.requireNonNull(url); @@ -234,6 +248,7 @@ public CustomServiceSettings( this.responseJsonParser = Objects.requireNonNull(responseJsonParser); this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); this.batchSize = Objects.requireNonNullElse(batchSize, DEFAULT_EMBEDDING_BATCH_SIZE); + this.inputTypeTranslator = Objects.requireNonNull(inputTypeTranslator); } public CustomServiceSettings(StreamInput in) throws IOException { @@ -258,6 +273,13 @@ public CustomServiceSettings(StreamInput in) throws IOException { } else { batchSize = DEFAULT_EMBEDDING_BATCH_SIZE; } + + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE) + || in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19)) { + inputTypeTranslator = new InputTypeTranslator(in); + } else { + inputTypeTranslator = InputTypeTranslator.EMPTY_TRANSLATOR; + } } @Override @@ -305,6 +327,10 @@ public CustomResponseParser getResponseJsonParser() { return responseJsonParser; } + public InputTypeTranslator getInputTypeTranslator() { + return inputTypeTranslator; + } + public int getBatchSize() { return batchSize; } @@ -352,6 +378,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder } builder.endObject(); + inputTypeTranslator.toXContent(builder, params); + rateLimitSettings.toXContent(builder, params); builder.field(BATCH_SIZE, batchSize); @@ -390,6 +418,11 @@ public void writeTo(StreamOutput out) throws IOException { || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) { out.writeVInt(batchSize); } + + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE) + || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19)) { + inputTypeTranslator.writeTo(out); + } } @Override @@ -404,7 +437,8 @@ public boolean equals(Object o) { && Objects.equals(requestContentString, that.requestContentString) && Objects.equals(responseJsonParser, that.responseJsonParser) && Objects.equals(rateLimitSettings, that.rateLimitSettings) - && Objects.equals(batchSize, that.batchSize); + && Objects.equals(batchSize, that.batchSize) + && Objects.equals(inputTypeTranslator, that.inputTypeTranslator); } @Override @@ -417,7 +451,8 @@ public int hashCode() { requestContentString, responseJsonParser, rateLimitSettings, - batchSize + batchSize, + inputTypeTranslator ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java new file mode 100644 index 0000000000000..c6a1e552eb279 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEmptyString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class InputTypeTranslator implements ToXContentFragment, Writeable { + public static final String INPUT_TYPE_TRANSLATOR = "input_type"; + public static final String TRANSLATION = "translation"; + public static final String DEFAULT = "default"; + public static final InputTypeTranslator EMPTY_TRANSLATOR = new InputTypeTranslator(null, null); + + public static InputTypeTranslator fromMap(Map map, ValidationException validationException, String serviceName) { + if (map == null || map.isEmpty()) { + return EMPTY_TRANSLATOR; + } + + Map inputTypeTranslation = Objects.requireNonNullElse( + extractOptionalMap(map, INPUT_TYPE_TRANSLATOR, ModelConfigurations.SERVICE_SETTINGS, validationException), + new HashMap<>(Map.of()) + ); + + Map translationMap = extractOptionalMap( + inputTypeTranslation, + TRANSLATION, + INPUT_TYPE_TRANSLATOR, + validationException + ); + + var validatedTranslation = InputType.validateInputTypeTranslationValues(translationMap, validationException); + + var defaultValue = extractOptionalEmptyString(inputTypeTranslation, DEFAULT, validationException); + + throwIfNotEmptyMap(inputTypeTranslation, INPUT_TYPE_TRANSLATOR, "input_type_translator"); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new InputTypeTranslator(validatedTranslation, defaultValue); + } + + private final Map inputTypeTranslation; + private final String defaultValue; + + public InputTypeTranslator(@Nullable Map inputTypeTranslation, @Nullable String defaultValue) { + this.inputTypeTranslation = Objects.requireNonNullElse(inputTypeTranslation, Map.of()); + this.defaultValue = Objects.requireNonNullElse(defaultValue, ""); + } + + public InputTypeTranslator(StreamInput in) throws IOException { + this.inputTypeTranslation = in.readImmutableMap(keyReader -> keyReader.readEnum(InputType.class), StreamInput::readString); + this.defaultValue = in.readString(); + } + + public Map getTranslation() { + return inputTypeTranslation; + } + + public String getDefaultValue() { + return defaultValue; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + var sortedMap = new TreeMap<>(inputTypeTranslation); + + builder.startObject(INPUT_TYPE_TRANSLATOR); + { + builder.startObject(TRANSLATION); + for (var entry : sortedMap.entrySet()) { + builder.field(entry.getKey().toString(), entry.getValue()); + } + builder.endObject(); + builder.field(DEFAULT, defaultValue); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(inputTypeTranslation, StreamOutput::writeEnum, StreamOutput::writeString); + out.writeString(defaultValue); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + InputTypeTranslator that = (InputTypeTranslator) o; + return Objects.equals(inputTypeTranslation, that.inputTypeTranslation) && Objects.equals(defaultValue, that.defaultValue); + } + + @Override + public int hashCode() { + return Objects.hash(inputTypeTranslation, defaultValue); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java new file mode 100644 index 0000000000000..a1ec2a3515131 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson; + +public class CompletionParameters extends RequestParameters { + + public static CompletionParameters of(ChatCompletionInput completionInput) { + return new CompletionParameters(Objects.requireNonNull(completionInput)); + } + + private CompletionParameters(ChatCompletionInput completionInput) { + super(completionInput.getInputs()); + } + + @Override + public Map jsonParameters() { + String jsonRep; + + if (inputs.isEmpty() == false) { + jsonRep = toJson(inputs.get(0), INPUT); + } else { + jsonRep = toJson("", INPUT); + } + + return Map.of(INPUT, jsonRep); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java index 1ea73f336bebb..74973adb9227a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java @@ -14,7 +14,6 @@ import org.apache.http.entity.StringEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.common.ValidatingSubstitutor; import org.elasticsearch.xpack.inference.external.request.HttpRequest; @@ -26,7 +25,6 @@ import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Objects; @@ -35,15 +33,13 @@ import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.URL; public class CustomRequest implements Request { - private static final String QUERY = "query"; - private static final String INPUT = "input"; private final URI uri; private final ValidatingSubstitutor jsonPlaceholderReplacer; private final ValidatingSubstitutor stringPlaceholderReplacer; private final CustomModel model; - public CustomRequest(String query, List input, CustomModel model) { + public CustomRequest(RequestParameters requestParams, CustomModel model) { this.model = Objects.requireNonNull(model); var stringOnlyParams = new HashMap(); @@ -54,11 +50,7 @@ public CustomRequest(String query, List input, CustomModel model) { addJsonStringParams(jsonParams, model.getSecretSettings().getSecretParameters()); addJsonStringParams(jsonParams, model.getTaskSettings().getParameters()); - if (query != null) { - jsonParams.put(QUERY, toJson(query, QUERY)); - } - - addInputJsonParam(jsonParams, input, model.getTaskType()); + jsonParams.putAll(requestParams.jsonParameters()); jsonPlaceholderReplacer = new ValidatingSubstitutor(jsonParams, "${", "}"); stringPlaceholderReplacer = new ValidatingSubstitutor(stringOnlyParams, "${", "}"); @@ -81,14 +73,6 @@ private static void addJsonStringParams(Map jsonStringParams, Ma } } - private static void addInputJsonParam(Map jsonParams, List input, TaskType taskType) { - if (taskType == TaskType.COMPLETION && input.isEmpty() == false) { - jsonParams.put(INPUT, toJson(input.get(0), INPUT)); - } else { - jsonParams.put(INPUT, toJson(input, INPUT)); - } - } - private URI buildUri() { var replacedUrl = stringPlaceholderReplacer.replace(model.getServiceSettings().getUrl(), URL); @@ -104,7 +88,6 @@ private URI buildUri() { } catch (URISyntaxException e) { throw new IllegalStateException(Strings.format("Failed to build URI, error: %s", e.getMessage()), e); } - } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java new file mode 100644 index 0000000000000..71eb0ed8e098d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson; + +public class EmbeddingParameters extends RequestParameters { + private static final String INPUT_TYPE = "input_type"; + + public static EmbeddingParameters of(EmbeddingsInput embeddingsInput, InputTypeTranslator inputTypeTranslator) { + return new EmbeddingParameters(Objects.requireNonNull(embeddingsInput), Objects.requireNonNull(inputTypeTranslator)); + } + + private final InputType inputType; + private final InputTypeTranslator translator; + + private EmbeddingParameters(EmbeddingsInput embeddingsInput, InputTypeTranslator translator) { + super(embeddingsInput.getStringInputs()); + this.inputType = embeddingsInput.getInputType(); + this.translator = translator; + } + + @Override + protected Map taskTypeParameters() { + var additionalParameters = new HashMap(); + + if (inputType != null && translator.getTranslation().containsKey(inputType)) { + var inputTypeTranslation = translator.getTranslation().get(inputType); + + additionalParameters.put(INPUT_TYPE, toJson(inputTypeTranslation, INPUT_TYPE)); + } else { + additionalParameters.put(INPUT_TYPE, toJson(translator.getDefaultValue(), INPUT_TYPE)); + } + + return additionalParameters; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java new file mode 100644 index 0000000000000..b03acc39c49d0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson; + +public abstract class RequestParameters { + + public static final String INPUT = "input"; + + protected final List inputs; + + public RequestParameters(List inputs) { + this.inputs = Objects.requireNonNull(inputs); + } + + Map jsonParameters() { + var additionalParameters = taskTypeParameters(); + var totalParameters = new HashMap<>(additionalParameters); + totalParameters.put(INPUT, toJson(inputs, INPUT)); + + return totalParameters; + } + + protected Map taskTypeParameters() { + return Map.of(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java new file mode 100644 index 0000000000000..a6503c3ef65dc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson; + +public class RerankParameters extends RequestParameters { + private static final String QUERY = "query"; + + public static RerankParameters of(QueryAndDocsInputs queryAndDocsInputs) { + Objects.requireNonNull(queryAndDocsInputs); + + return new RerankParameters(queryAndDocsInputs); + } + + private final QueryAndDocsInputs queryAndDocsInputs; + + private RerankParameters(QueryAndDocsInputs queryAndDocsInputs) { + super(queryAndDocsInputs.getChunks()); + this.queryAndDocsInputs = queryAndDocsInputs; + } + + @Override + protected Map taskTypeParameters() { + var additionalParameters = new HashMap(); + additionalParameters.put(QUERY, toJson(queryAndDocsInputs.getQuery(), QUERY)); + if (queryAndDocsInputs.getTopN() != null) { + additionalParameters.put( + InferenceAction.Request.TOP_N.getPreferredName(), + toJson(queryAndDocsInputs.getTopN(), InferenceAction.Request.TOP_N.getPreferredName()) + ); + } + + if (queryAndDocsInputs.getReturnDocuments() != null) { + additionalParameters.put( + InferenceAction.Request.RETURN_DOCUMENTS.getPreferredName(), + toJson(queryAndDocsInputs.getReturnDocuments(), InferenceAction.Request.RETURN_DOCUMENTS.getPreferredName()) + ); + } + return additionalParameters; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java index 624f223c9f3e1..860e300a1c197 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java @@ -11,6 +11,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; public class ChatCompletionModelValidator implements ModelValidator { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java new file mode 100644 index 0000000000000..aee1ed3ec4ebc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.validation; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; +import org.elasticsearch.rest.RestStatus; + +import java.util.List; +import java.util.Map; + +/** + * This class is slightly different from the SimpleServiceIntegrationValidator in that in sends the topN and return documents in the + * request. This is necessary because the custom service may require those template to be replaced when building the request. Otherwise, + * the request will fail to be constructed because it'll have a template that wasn't replaced. + */ +public class CustomServiceIntegrationValidator implements ServiceIntegrationValidator { + private static final List TEST_INPUT = List.of("how big"); + private static final String QUERY = "test query"; + + @Override + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { + service.infer( + model, + model.getTaskType().equals(TaskType.RERANK) ? QUERY : null, + true, + 1, + TEST_INPUT, + false, + Map.of(), + InputType.INTERNAL_INGEST, + timeout, + ActionListener.wrap(r -> { + if (r != null) { + listener.onResponse(r); + } else { + listener.onFailure( + new ElasticsearchStatusException( + "Could not complete custom service inference endpoint creation as" + + " validation call to service returned null response.", + RestStatus.BAD_REQUEST + ) + ); + } + }, + e -> listener.onFailure( + new ElasticsearchStatusException( + "Could not complete custom service inference endpoint creation as validation call to service threw an exception.", + RestStatus.BAD_REQUEST, + e + ) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java index 8fdb511ab31c6..fa0e1b3e590a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java @@ -15,6 +15,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java index d9eed10268527..fac9ee5e9c1c1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java @@ -8,34 +8,50 @@ package org.elasticsearch.xpack.inference.services.validation; import org.elasticsearch.core.Strings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; + +import java.util.Objects; public class ModelValidatorBuilder { - public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) { - if (isElasticsearchInternalService) { + public static ModelValidator buildModelValidator(TaskType taskType, InferenceService service) { + if (service instanceof ElasticsearchInternalService) { return new ElasticsearchInternalServiceModelValidator(new SimpleServiceIntegrationValidator()); } else { - return buildModelValidatorForTaskType(taskType); + return buildModelValidatorForTaskType(taskType, service); } } - private static ModelValidator buildModelValidatorForTaskType(TaskType taskType) { + private static ModelValidator buildModelValidatorForTaskType(TaskType taskType, InferenceService service) { if (taskType == null) { throw new IllegalArgumentException("Task type can't be null"); } + ServiceIntegrationValidator validatorFromService = null; + if (service != null) { + validatorFromService = service.getServiceIntegrationValidator(taskType); + } + switch (taskType) { case TEXT_EMBEDDING -> { - return new TextEmbeddingModelValidator(new SimpleServiceIntegrationValidator()); + return new TextEmbeddingModelValidator( + Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator()) + ); } case COMPLETION -> { - return new ChatCompletionModelValidator(new SimpleServiceIntegrationValidator()); + return new ChatCompletionModelValidator( + Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator()) + ); } case CHAT_COMPLETION -> { - return new ChatCompletionModelValidator(new SimpleChatCompletionServiceIntegrationValidator()); + return new ChatCompletionModelValidator( + Objects.requireNonNullElse(validatorFromService, new SimpleChatCompletionServiceIntegrationValidator()) + ); } case SPARSE_EMBEDDING, RERANK, ANY -> { - return new SimpleModelValidator(new SimpleServiceIntegrationValidator()); + return new SimpleModelValidator(Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator())); } default -> throw new IllegalArgumentException(Strings.format("Can't validate inference model for task type %s", taskType)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java index f9cf67172bc2a..61ec93e541924 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java @@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java index 3d592840f533b..41bdfc95dc5a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java @@ -11,6 +11,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; public class SimpleModelValidator implements ModelValidator { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java index 03ac5b95fddc5..d2cb1925ad7d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java index bff04f5af2d75..ce9df7376ebcb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java @@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java index cd51de0b57125..c6e08a8d5bdf3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java @@ -7,10 +7,13 @@ package org.elasticsearch.xpack.inference; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.elasticsearch.core.Strings.format; import static org.hamcrest.CoreMatchers.is; @@ -80,4 +83,132 @@ public void testFromRestString_ThrowsErrorForInvalidInputTypes() { assertThat(thrownException.getMessage(), is("No enum constant org.elasticsearch.inference.InputType.FOO")); } + + public void testValidateInputTypeTranslationValues() { + assertThat( + InputType.validateInputTypeTranslationValues( + Map.of( + InputType.INGEST.toString(), + "ingest_value", + InputType.SEARCH.toString(), + "search_value", + InputType.CLASSIFICATION.toString(), + "classification_value", + InputType.CLUSTERING.toString(), + "clustering_value" + ), + new ValidationException() + ), + is( + Map.of( + InputType.INGEST, + "ingest_value", + InputType.SEARCH, + "search_value", + InputType.CLASSIFICATION, + "classification_value", + InputType.CLUSTERING, + "clustering_value" + ) + ) + ); + } + + public void testValidateInputTypeTranslationValues_ReturnsEmptyMap_WhenTranslationIsNull() { + assertThat(InputType.validateInputTypeTranslationValues(null, new ValidationException()), is(Map.of())); + } + + public void testValidateInputTypeTranslationValues_ReturnsEmptyMap_WhenTranslationIsAnEmptyMap() { + assertThat(InputType.validateInputTypeTranslationValues(Map.of(), new ValidationException()), is(Map.of())); + } + + public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenInputTypeIsUnspecified() { + var exception = expectThrows( + ValidationException.class, + () -> InputType.validateInputTypeTranslationValues( + Map.of(InputType.INGEST.toString(), "ingest_value", InputType.UNSPECIFIED.toString(), "unspecified_value"), + new ValidationException() + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Invalid input type translation for key: [unspecified], is not a valid value. Must be " + + "one of [ingest, search, classification, clustering];" + ) + ); + } + + public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenInputTypeIsInternal() { + var exception = expectThrows( + ValidationException.class, + () -> InputType.validateInputTypeTranslationValues( + Map.of(InputType.INGEST.toString(), "ingest_value", InputType.INTERNAL_INGEST.toString(), "internal_ingest_value"), + new ValidationException() + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Invalid input type translation for key: [internal_ingest], is not a valid value. Must be " + + "one of [ingest, search, classification, clustering];" + ) + ); + } + + public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenValueIsNull() { + var translation = new HashMap(); + translation.put(InputType.INGEST.toString(), null); + + var exception = expectThrows( + ValidationException.class, + () -> InputType.validateInputTypeTranslationValues(translation, new ValidationException()) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Input type translation value for key [ingest] must be a String that " + + "is not null and not empty, received: [null], type: [null].;" + ) + ); + } + + public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenValueIsAnEmptyString() { + var translation = new HashMap(); + translation.put(InputType.INGEST.toString(), ""); + + var exception = expectThrows( + ValidationException.class, + () -> InputType.validateInputTypeTranslationValues(translation, new ValidationException()) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Input type translation value for key [ingest] must be a String that " + + "is not null and not empty, received: [], type: [String].;" + ) + ); + } + + public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenValueIsNotAString() { + var translation = new HashMap(); + translation.put(InputType.INGEST.toString(), 1); + + var exception = expectThrows( + ValidationException.class, + () -> InputType.validateInputTypeTranslationValues(translation, new ValidationException()) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Input type translation value for key [ingest] must be a String that " + + "is not null and not empty, received: [1], type: [Integer].;" + ) + ); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java index 5712ff363bb07..aeb09af03ebab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; @@ -460,33 +459,6 @@ public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException { } } - public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { - var listener = new PlainActionFuture(); - - var exception = expectThrows( - ValidationException.class, - () -> service.infer( - getInvalidModel("id", "service"), - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) - ); - - assertThat( - exception.getMessage(), - is("Validation Failed: 1: Invalid input_type [ingest]. The input_type option is not supported by this service;") - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index 47346fc896baa..d25db835fae04 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; @@ -157,7 +158,8 @@ public void testFromMap() { requestContentString, responseParser, new RateLimitSettings(10_000), - 11 + 11, + InputTypeTranslator.EMPTY_TRANSLATOR ) ) ); @@ -580,6 +582,56 @@ public void testXContent() throws IOException { "text_embeddings": "$.result.embeddings[*].embedding" } }, + "input_type": { + "translation": {}, + "default": "" + }, + "rate_limit": { + "requests_per_minute": 10000 + }, + "batch_size": 10 + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testXContent_WithInputTypeTranslationValues() throws IOException { + var entity = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.abc.com", + Map.of("key", "value"), + null, + "string", + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + null, + null, + new InputTypeTranslator(Map.of(InputType.SEARCH, "do_search", InputType.INGEST, "do_ingest"), "a_default") + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "url": "http://www.abc.com", + "headers": { + "key": "value" + }, + "request": "string", + "response": { + "json_parser": { + "text_embeddings": "$.result.embeddings[*].embedding" + } + }, + "input_type": { + "translation": { + "ingest": "do_ingest", + "search": "do_search" + }, + "default": "a_default" + }, "rate_limit": { "requests_per_minute": 10000 }, @@ -599,7 +651,8 @@ public void testXContent_BatchSize11() throws IOException { "string", new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), null, - 11 + 11, + InputTypeTranslator.EMPTY_TRANSLATOR ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -618,6 +671,10 @@ public void testXContent_BatchSize11() throws IOException { "text_embeddings": "$.result.embeddings[*].embedding" } }, + "input_type": { + "translation": {}, + "default": "" + }, "rate_limit": { "requests_per_minute": 10000 }, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index d268b301dde8d..6ddb4ff71eeb3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -147,7 +147,7 @@ private static void assertCompletionModel(Model model) { assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(CompletionResponseParser.class)); } - private static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); return new CustomService(senderFactory, createWithEmptySettings(threadPool)); } @@ -278,7 +278,8 @@ private static CustomModel createInternalEmbeddingModel( "{\"input\":${input}}", parser, new RateLimitSettings(10_000), - batchSize + batchSize, + InputTypeTranslator.EMPTY_TRANSLATOR ), new CustomTaskSettings(Map.of("key", "test_value")), new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray()))), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java new file mode 100644 index 0000000000000..0d53c11967d4c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java @@ -0,0 +1,269 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.core.Tuple.tuple; +import static org.hamcrest.Matchers.is; + +public class InputTypeTranslatorTests extends AbstractBWCWireSerializationTestCase { + public static InputTypeTranslator createRandom() { + Map translation = randomBoolean() + ? randomMap( + 0, + 5, + () -> tuple( + randomFrom(List.of(InputType.CLASSIFICATION, InputType.CLUSTERING, InputType.INGEST, InputType.SEARCH)), + randomAlphaOfLength(5) + ) + ) + : Map.of(); + return new InputTypeTranslator(translation, randomAlphaOfLength(5)); + } + + public void testFromMap() { + var settings = new HashMap( + Map.of( + InputTypeTranslator.INPUT_TYPE_TRANSLATOR, + new HashMap<>( + Map.of( + InputTypeTranslator.TRANSLATION, + new HashMap<>( + Map.of( + "CLASSIFICATION", + "test_value", + "CLUSTERING", + "test_value_2", + "INGEST", + "test_value_3", + "SEARCH", + "test_value_4" + ) + ), + InputTypeTranslator.DEFAULT, + "default_value" + ) + ) + ) + ); + + assertThat( + InputTypeTranslator.fromMap(settings, new ValidationException(), "name"), + is( + new InputTypeTranslator( + Map.of( + InputType.CLASSIFICATION, + "test_value", + InputType.CLUSTERING, + "test_value_2", + InputType.INGEST, + "test_value_3", + InputType.SEARCH, + "test_value_4" + ), + "default_value" + ) + ) + ); + } + + public void testFromMap_Null_EmptyMap_Returns_EmptySettings() { + assertThat(InputTypeTranslator.fromMap(null, null, null), is(InputTypeTranslator.EMPTY_TRANSLATOR)); + assertThat(InputTypeTranslator.fromMap(Map.of(), null, null), is(InputTypeTranslator.EMPTY_TRANSLATOR)); + } + + public void testFromMap_Throws_IfValueIsNotAString() { + var settings = new HashMap( + Map.of( + InputTypeTranslator.INPUT_TYPE_TRANSLATOR, + new HashMap<>(Map.of(InputTypeTranslator.TRANSLATION, new HashMap<>(Map.of("CLASSIFICATION", 12345)))) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> InputTypeTranslator.fromMap(settings, new ValidationException(), "name") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Input type translation value for key [CLASSIFICATION] " + + "must be a String that is not null and not empty, received: [12345], type: [Integer].;" + ) + ); + } + + public void testFromMap_Throws_IfValueIsEmptyString() { + var settings = new HashMap( + Map.of( + InputTypeTranslator.INPUT_TYPE_TRANSLATOR, + new HashMap<>(Map.of(InputTypeTranslator.TRANSLATION, new HashMap<>(Map.of("CLASSIFICATION", "")))) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> InputTypeTranslator.fromMap(settings, new ValidationException(), "name") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Input type translation value for key [CLASSIFICATION] " + + "must be a String that is not null and not empty, received: [], type: [String].;" + ) + ); + } + + public void testFromMap_DoesNotThrow_ForAnEmptyDefaultValue() { + var settings = new HashMap( + Map.of( + InputTypeTranslator.INPUT_TYPE_TRANSLATOR, + new HashMap<>( + Map.of( + InputTypeTranslator.TRANSLATION, + new HashMap<>(Map.of("CLASSIFICATION", "value")), + InputTypeTranslator.DEFAULT, + "" + ) + ) + ) + ); + + var translator = InputTypeTranslator.fromMap(settings, new ValidationException(), "name"); + + assertThat(translator, is(new InputTypeTranslator(Map.of(InputType.CLASSIFICATION, "value"), ""))); + } + + public void testFromMap_Throws_IfKeyIsInvalid() { + var settings = new HashMap( + Map.of( + InputTypeTranslator.INPUT_TYPE_TRANSLATOR, + new HashMap<>( + Map.of( + InputTypeTranslator.TRANSLATION, + new HashMap<>(Map.of("CLASSIFICATION", "test_value", "invalid_key", "another_value")) + ) + ) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> InputTypeTranslator.fromMap(settings, new ValidationException(), "name") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Invalid input type translation for key: [invalid_key]" + + ", is not a valid value. Must be one of [ingest, search, classification, clustering];" + ) + ); + } + + public void testFromMap_DefaultsToEmptyMap_WhenField_DoesNotExist() { + var map = new HashMap(Map.of("key", new HashMap<>(Map.of("test_key", "test_value")))); + + assertThat(InputTypeTranslator.fromMap(map, new ValidationException(), "name"), is(new InputTypeTranslator(Map.of(), null))); + } + + public void testXContent() throws IOException { + var entity = new InputTypeTranslator(Map.of(InputType.CLASSIFICATION, "test_value"), "default"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "input_type": { + "translation": { + "classification": "test_value" + }, + "default": "default" + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testXContent_EmptyTranslator() throws IOException { + var entity = new InputTypeTranslator(Map.of(), null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "input_type": { + "translation": {}, + "default": "" + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return InputTypeTranslator::new; + } + + @Override + protected InputTypeTranslator createTestInstance() { + return createRandom(); + } + + @Override + protected InputTypeTranslator mutateInstance(InputTypeTranslator instance) throws IOException { + return randomValueOtherThan(instance, InputTypeTranslatorTests::createRandom); + } + + public static Map getTaskSettingsMap(@Nullable Map parameters) { + var map = new HashMap(); + if (parameters != null) { + map.put(CustomTaskSettings.PARAMETERS, parameters); + } + + return map; + } + + @Override + protected InputTypeTranslator mutateInstanceForVersion(InputTypeTranslator instance, TransportVersion version) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParametersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParametersTests.java new file mode 100644 index 0000000000000..eb2ee9c3d6d27 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParametersTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; + +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.custom.request.RequestParameters.INPUT; +import static org.hamcrest.Matchers.is; + +public class CompletionParametersTests extends ESTestCase { + + public void testJsonParameters_SingleValue() { + var parameters = CompletionParameters.of(new ChatCompletionInput(List.of("hello"))); + assertThat(parameters.jsonParameters(), is(Map.of(INPUT, "\"hello\""))); + } + + public void testJsonParameters_RetrievesFirstEntryFromList() { + var parameters = CompletionParameters.of(new ChatCompletionInput(List.of("hello", "hi"))); + assertThat(parameters.jsonParameters(), is(Map.of(INPUT, "\"hello\""))); + } + + public void testJsonParameters_EmptyList() { + var parameters = CompletionParameters.of(new ChatCompletionInput(List.of())); + assertThat(parameters.jsonParameters(), is(Map.of(INPUT, "\"\""))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index 3ecacdb17cf93..d1f606daef529 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -14,14 +14,18 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings; import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings; import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; +import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator; import org.elasticsearch.xpack.inference.services.custom.QueryParameters; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; @@ -46,7 +50,8 @@ public void testCreateRequest() throws IOException { Map headers = Map.of(HttpHeaders.AUTHORIZATION, Strings.format("${api_key}")); var requestContentString = """ { - "input": ${input} + "input": ${input}, + "input_type": ${input_type} } """; @@ -62,7 +67,9 @@ public void testCreateRequest() throws IOException { new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), - new RateLimitSettings(10_000) + new RateLimitSettings(10_000), + null, + new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") ); var model = CustomModelTests.createModel( @@ -73,7 +80,13 @@ public void testCreateRequest() throws IOException { new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray()))) ); - var request = new CustomRequest(null, List.of("abc", "123"), model); + var request = new CustomRequest( + EmbeddingParameters.of( + new EmbeddingsInput(List.of("abc", "123"), null, null), + model.getServiceSettings().getInputTypeTranslator() + ), + model + ); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -84,18 +97,20 @@ public void testCreateRequest() throws IOException { var expectedBody = XContentHelper.stripWhitespace(""" { - "input": ["abc", "123"] + "input": ["abc", "123"], + "input_type": "default" } """); assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); } - public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { + public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOException { var inferenceId = "inferenceId"; var requestContentString = """ { - "input": ${input} + "input": ${input}, + "input_type": ${input_type} } """; @@ -115,7 +130,9 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { ), requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), - new RateLimitSettings(10_000) + new RateLimitSettings(10_000), + null, + new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") ); var model = CustomModelTests.createModel( @@ -126,7 +143,13 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray()))) ); - var request = new CustomRequest(null, List.of("abc", "123"), model); + var request = new CustomRequest( + EmbeddingParameters.of( + new EmbeddingsInput(List.of("abc", "123"), null, InputType.INGEST), + model.getServiceSettings().getInputTypeTranslator() + ), + model + ); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -136,6 +159,14 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { // To visually verify that this is correct, input the query parameters into here: https://www.urldecoder.org/ is("http://www.elastic.co?key=+%3C%3E%23%25%2B%7B%7D%7C%5C%5E%7E%5B%5D%60%3B%2F%3F%3A%40%3D%26%24&key=%CE%A3+%F0%9F%98%80") ); + + var expectedBody = XContentHelper.stripWhitespace(""" + { + "input": ["abc", "123"], + "input_type": "value" + } + """); + assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); } public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws IOException { @@ -173,7 +204,13 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray()))) ); - var request = new CustomRequest(null, List.of("abc", "123"), model); + var request = new CustomRequest( + EmbeddingParameters.of( + new EmbeddingsInput(List.of("abc", "123"), null, InputType.SEARCH), + model.getServiceSettings().getInputTypeTranslator() + ), + model + ); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -220,7 +257,7 @@ public void testCreateRequest_HandlesQuery() throws IOException { new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray()))) ); - var request = new CustomRequest("query string", List.of("abc", "123"), model); + var request = new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -236,6 +273,56 @@ public void testCreateRequest_HandlesQuery() throws IOException { assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); } + public void testCreateRequest_HandlesQuery_WithReturnDocsAndTopN() throws IOException { + var inferenceId = "inference_id"; + var requestContentString = """ + { + "input": ${input}, + "query": ${query}, + "return_documents": ${return_documents}, + "top_n": ${top_n} + } + """; + + var serviceSettings = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.elastic.co", + null, + null, + requestContentString, + new RerankResponseParser("$.result.score"), + new RateLimitSettings(10_000) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.RERANK, + serviceSettings, + new CustomTaskSettings(Map.of()), + new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray()))) + ); + + var request = new CustomRequest( + RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"), false, 2, false)), + model + ); + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var expectedBody = XContentHelper.stripWhitespace(""" + { + "input": ["abc", "123"], + "query": "query string", + "return_documents": false, + "top_n": 2 + } + """); + + assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); + } + public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IOException { var inferenceId = "inference_id"; var requestContentString = """ @@ -262,7 +349,7 @@ public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IO new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray()))) ); - var request = new CustomRequest(null, List.of("abc", "123"), model); + var request = new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model); var exception = expectThrows(IllegalStateException.class, request::createHttpRequest); assertThat( exception.getMessage(), @@ -299,7 +386,10 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() { new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray()))) ); - var exception = expectThrows(IllegalStateException.class, () -> new CustomRequest(null, List.of("abc", "123"), model)); + var exception = expectThrows( + IllegalStateException.class, + () -> new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model) + ); assertThat(exception.getMessage(), is("Failed to build URI, error: Illegal character in path at index 0: ^")); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java new file mode 100644 index 0000000000000..5231802ef4b92 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator; + +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class EmbeddingParametersTests extends ESTestCase { + + public void testTaskTypeParameters_UsesDefaultValue() { + var parameters = EmbeddingParameters.of( + new EmbeddingsInput(List.of("input"), null, InputType.INGEST), + new InputTypeTranslator(Map.of(), "default") + ); + + assertThat(parameters.taskTypeParameters(), is(Map.of("input_type", "\"default\""))); + } + + public void testTaskTypeParameters_UsesMappedValue() { + var parameters = EmbeddingParameters.of( + new EmbeddingsInput(List.of("input"), null, InputType.INGEST), + new InputTypeTranslator(Map.of(InputType.INGEST, "ingest_value"), "default") + ); + + assertThat(parameters.taskTypeParameters(), is(Map.of("input_type", "\"ingest_value\""))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParametersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParametersTests.java new file mode 100644 index 0000000000000..26a5b9dd5b501 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParametersTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; + +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class RerankParametersTests extends ESTestCase { + + public void testTaskTypeParameters() { + var queryAndDocsInputs = new QueryAndDocsInputs("query_value", List.of("doc1", "doc2"), true, 5, false); + var parameters = RerankParameters.of(queryAndDocsInputs); + + assertThat(parameters.taskTypeParameters(), is(Map.of("query", "\"query_value\"", "top_n", "5", "return_documents", "true"))); + } + + public void testTaskTypeParameters_WithoutOptionalFields() { + var queryAndDocsInputs = new QueryAndDocsInputs("query_value", List.of("doc1", "doc2")); + var parameters = RerankParameters.of(queryAndDocsInputs); + + assertThat(parameters.taskTypeParameters(), is(Map.of("query", "\"query_value\""))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java index c5dabad7903bc..608fdb4d314c3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java @@ -17,8 +17,14 @@ import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; +import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters; import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; +import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters; +import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -53,10 +59,13 @@ public void testFromTextEmbeddingResponse() throws IOException { } """; + var model = CustomModelTests.getTestModel( + TaskType.TEXT_EMBEDDING, + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding") + ); var request = new CustomRequest( - null, - List.of("abc"), - CustomModelTests.getTestModel(TaskType.TEXT_EMBEDDING, new TextEmbeddingResponseParser("$.result.embeddings[*].embedding")) + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), + model ); InferenceServiceResults results = CustomResponseEntity.fromResponse( request, @@ -98,17 +107,18 @@ public void testFromSparseEmbeddingResponse() throws IOException { } """; - var request = new CustomRequest( - null, - List.of("abc"), - CustomModelTests.getTestModel( - TaskType.SPARSE_EMBEDDING, - new SparseEmbeddingResponseParser( - "$.result.sparse_embeddings[*].embedding[*].tokenId", - "$.result.sparse_embeddings[*].embedding[*].weight" - ) + var model = CustomModelTests.getTestModel( + TaskType.SPARSE_EMBEDDING, + new SparseEmbeddingResponseParser( + "$.result.sparse_embeddings[*].embedding[*].tokenId", + "$.result.sparse_embeddings[*].embedding[*].weight" ) ); + var request = new CustomRequest( + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), + model + + ); InferenceServiceResults results = CustomResponseEntity.fromResponse( request, @@ -152,14 +162,11 @@ public void testFromRerankResponse() throws IOException { } """; - var request = new CustomRequest( - null, - List.of("abc"), - CustomModelTests.getTestModel( - TaskType.RERANK, - new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null) - ) + var model = CustomModelTests.getTestModel( + TaskType.RERANK, + new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null) ); + var request = new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query", List.of("doc1", "doc2"))), model); InferenceServiceResults results = CustomResponseEntity.fromResponse( request, @@ -193,8 +200,7 @@ public void testFromCompletionResponse() throws IOException { """; var request = new CustomRequest( - null, - List.of("abc"), + CompletionParameters.of(new ChatCompletionInput(List.of("abc"))), CustomModelTests.getTestModel(TaskType.COMPLETION, new CompletionResponseParser("$.result.text")) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java index bd52bdd52ab3e..03214b770bdc7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.test.ESTestCase; import org.junit.Before; import org.mockito.Mock; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java index 19ea0bedaaea5..e26ffa137c047 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java @@ -7,21 +7,88 @@ package org.elasticsearch.xpack.inference.services.validation; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceTests; +import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; +import org.junit.After; +import org.junit.Before; +import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.hamcrest.Matchers.isA; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; public class ModelValidatorBuilderTests extends ESTestCase { + + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + clientManager.close(); + terminate(threadPool); + } + + public void testCustomServiceValidator() { + var service = CustomServiceTests.createService(threadPool, clientManager); + var validator = ModelValidatorBuilder.buildModelValidator(TaskType.RERANK, service); + var mockService = mock(InferenceService.class); + validator.validate( + mockService, + CustomModelTests.getTestModel(TaskType.RERANK, new RerankResponseParser("score")), + null, + ActionListener.noop() + ); + + verify(mockService, times(1)).infer( + any(), + eq("test query"), + eq(true), + eq(1), + eq(List.of("how big")), + eq(false), + eq(Map.of()), + eq(InputType.INTERNAL_INGEST), + any(), + any() + ); + verifyNoMoreInteractions(mockService); + } + public void testBuildModelValidator_NullTaskType() { - assertThrows(IllegalArgumentException.class, () -> { ModelValidatorBuilder.buildModelValidator(null, false); }); + assertThrows(IllegalArgumentException.class, () -> { ModelValidatorBuilder.buildModelValidator(null, null); }); } public void testBuildModelValidator_ValidTaskType() { taskTypeToModelValidatorClassMap().forEach((taskType, modelValidatorClass) -> { - assertThat(ModelValidatorBuilder.buildModelValidator(taskType, false), isA(modelValidatorClass)); + assertThat(ModelValidatorBuilder.buildModelValidator(taskType, null), isA(modelValidatorClass)); }); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java index 418584d25d085..a793d5ec088b7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.junit.Before; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java index 55a02ebab082b..45726f0789667 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;