Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions server/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@
org.elasticsearch.serverless.apifiltering;
exports org.elasticsearch.lucene.spatial;
exports org.elasticsearch.inference.configuration;
exports org.elasticsearch.inference.validation;
exports org.elasticsearch.monitor.metrics;
exports org.elasticsearch.plugins.internal.rewriter to org.elasticsearch.inference;
exports org.elasticsearch.lucene.util.automaton;
Expand Down
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ static TransportVersion def(int id) {
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19 = def(8_841_0_53);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -304,6 +305,7 @@ static TransportVersion def(int id) {
public static final TransportVersion STATE_PARAM_GET_SNAPSHOT = def(9_100_0_00);
public static final TransportVersion PROJECT_ID_IN_SNAPSHOTS_DELETIONS_AND_REPO_CLEANUP = def(9_101_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE = def(9_103_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.validation.ServiceIntegrationValidator;

import java.io.Closeable;
import java.util.EnumSet;
Expand Down Expand Up @@ -248,4 +249,14 @@ default void updateModelsWithDynamicFields(List<Model> model, ActionListener<Lis
* after ensuring the node's internals are set up (for example if this ensures the internal ES client is ready for use).
*/
default void onNodeStarted() {}

/**
* Returns the {@link ServiceIntegrationValidator} this service wants to use for the given task type.
* Services override this to supply custom validation logic.
* <p>
* This default implementation always returns {@code null}, which signals that the default
* validator should be used.
*
* @param taskType the task type a validator is being requested for
* @return the service-specific integration validator, or {@code null} if the default should be used
*/
default ServiceIntegrationValidator getServiceIntegrationValidator(TaskType taskType) {
return null;
}
}
64 changes: 64 additions & 0 deletions server/src/main/java/org/elasticsearch/inference/InputType.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
package org.elasticsearch.inference;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;

import java.util.EnumSet;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

import static org.elasticsearch.core.Strings.format;

Expand All @@ -29,6 +33,13 @@ public enum InputType {
INTERNAL_SEARCH,
INTERNAL_INGEST;

/**
* The input types grouped as supported request values: classification, clustering, ingest and search.
* The {@code INTERNAL_SEARCH} and {@code INTERNAL_INGEST} constants are not included.
*/
private static final EnumSet<InputType> SUPPORTED_REQUEST_VALUES = EnumSet.of(
InputType.CLASSIFICATION,
InputType.CLUSTERING,
InputType.INGEST,
InputType.SEARCH
);

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
Expand Down Expand Up @@ -57,4 +68,57 @@ public static boolean isSpecified(InputType inputType) {
/**
 * Builds the error message used to report an input type value that was not accepted.
 *
 * @param inputType the rejected input type; its {@link #toString()} form is embedded in the message,
 *                  so a {@code null} argument results in a {@link NullPointerException}
 * @return the formatted error message
 */
public static String invalidInputTypeMessage(InputType inputType) {
    final String renderedInputType = inputType.toString();
    return Strings.format("received invalid input type value [%s]", renderedInputType);
}

/**
* Ensures that a map used for translating input types is valid. The keys of the map are the external representation,
* and the values correspond to the values in this class.
* Throws a {@link ValidationException} if any value is not a valid InputType.
*
* @param inputTypeTranslation the map of input type translations to validate
* @param validationException a ValidationException to which errors will be added
*/
public static Map<InputType, String> validateInputTypeTranslationValues(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this

Map<String, Object> inputTypeTranslation,
ValidationException validationException
) {
if (inputTypeTranslation == null || inputTypeTranslation.isEmpty()) {
return Map.of();
}

var translationMap = new HashMap<InputType, String>();

for (var entry : inputTypeTranslation.entrySet()) {
var key = entry.getKey();
var value = entry.getValue();

if (value instanceof String == false || Strings.isNullOrEmpty((String) value)) {
validationException.addValidationError(
Strings.format(
"Input type translation value for key [%s] must be a String that is not null and not empty, received: [%s].",
key,
value.getClass().getSimpleName()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can value be null?

Suggested change
value.getClass().getSimpleName()
value == null ? "null" : value.getClass().getSimpleName()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, thank you.

)
);

throw validationException;
}

try {
var inputTypeKey = InputType.fromRestString(key);
translationMap.put(inputTypeKey, (String) value);
} catch (Exception e) {
validationException.addValidationError(
Strings.format(
"Invalid input type translation for key: [%s], is not a valid value. Must be one of %s",
key,
EnumSet.of(InputType.CLASSIFICATION, InputType.CLUSTERING, InputType.INGEST, InputType.SEARCH)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
EnumSet.of(InputType.CLASSIFICATION, InputType.CLUSTERING, InputType.INGEST, InputType.SEARCH)
SUPPORTED_REQUEST_VALUES

)
);

throw validationException;
}
}

return translationMap;
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.xpack.inference.services.validation;
package org.elasticsearch.inference.validation;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ private void parseAndStoreModel(
if (skipValidationAndStart) {
storeModelListener.onResponse(model);
} else {
ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService)
ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service)
.validate(service, model, timeout, storeModelListener);
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,18 @@ public static String extractRequiredString(
return requiredField;
}

/**
 * Extracts an optional string setting from the map via {@code ServiceUtils.removeAsType}.
 * This method itself performs no emptiness check on the extracted value.
 *
 * @param map the settings map to read {@code settingName} from
 * @param settingName the name of the setting to extract
 * @param validationException collects validation errors; its error count is used to detect a failed extraction
 * @return the extracted value, or {@code null} if the extraction added a new validation error
 */
public static String extractOptionalEmptyString(Map<String, Object> map, String settingName, ValidationException validationException) {
    final int errorsBefore = validationException.validationErrors().size();
    final String extracted = ServiceUtils.removeAsType(map, settingName, String.class, validationException);

    // A larger error count means the extraction itself reported a problem.
    final boolean newErrorRecorded = validationException.validationErrors().size() > errorsBefore;
    return newErrorRecorded ? null : extracted;
}

public static String extractOptionalString(
Map<String, Object> map,
String settingName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters;
import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters;
import org.elasticsearch.xpack.inference.services.custom.request.RequestParameters;
import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseEntity;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

Expand Down Expand Up @@ -65,19 +68,16 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
String query;
List<String> input;
RequestParameters requestParameters;
if (inferenceInputs instanceof QueryAndDocsInputs) {
QueryAndDocsInputs queryAndDocsInputs = QueryAndDocsInputs.of(inferenceInputs);
query = queryAndDocsInputs.getQuery();
input = queryAndDocsInputs.getChunks();
requestParameters = RerankParameters.of(QueryAndDocsInputs.of(inferenceInputs));
} else if (inferenceInputs instanceof ChatCompletionInput chatInputs) {
query = null;
input = chatInputs.getInputs();
requestParameters = CompletionParameters.of(chatInputs);
} else if (inferenceInputs instanceof EmbeddingsInput) {
EmbeddingsInput embeddingsInput = EmbeddingsInput.of(inferenceInputs);
query = null;
input = embeddingsInput.getStringInputs();
requestParameters = EmbeddingParameters.of(
EmbeddingsInput.of(inferenceInputs),
model.getServiceSettings().getInputTypeTranslator()
);
} else {
listener.onFailure(
new ElasticsearchStatusException(
Expand All @@ -89,7 +89,7 @@ public void execute(
}

try {
var request = new CustomRequest(query, input, model);
var request = new CustomRequest(requestParameters, model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, handler, hasRequestCompletedFunction, listener));
} catch (Exception e) {
// Intentionally not logging this exception because it could contain sensitive information from the CustomRequest construction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
Expand All @@ -35,7 +36,7 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.validation.CustomServiceIntegrationValidator;

import java.util.EnumSet;
import java.util.HashMap;
Expand Down Expand Up @@ -199,7 +200,8 @@ public void doInfer(

@Override
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException);
// The custom service doesn't do any validation for the input type because if the input type is supported a default
// must be supplied within the service settings.
}

@Override
Expand Down Expand Up @@ -249,7 +251,8 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom
serviceSettings.getQueryParameters(),
serviceSettings.getRequestContentString(),
serviceSettings.getResponseJsonParser(),
serviceSettings.rateLimitSettings()
serviceSettings.rateLimitSettings(),
serviceSettings.getInputTypeTranslator()
);
}

Expand All @@ -275,4 +278,13 @@ public static InferenceServiceConfiguration get() {
}
);
}

/**
 * Supplies a {@link CustomServiceIntegrationValidator} for the rerank task type only.
 *
 * @param taskType the task type a validator is being requested for
 * @return a new {@link CustomServiceIntegrationValidator} for {@link TaskType#RERANK}, otherwise {@code null}
 */
@Override
public ServiceIntegrationValidator getServiceIntegrationValidator(TaskType taskType) {
    return taskType == TaskType.RERANK ? new CustomServiceIntegrationValidator() : null;
}
}
Loading