Skip to content

Commit d9b34d4

Browse files
[ML] Custom service add support for input_type, top_n, and return_documents (#129441)
* Making progress on different request parameters
* Working tests
* Adding custom service validator for rerank
* Fixing embedding bug
* Adding transport version check
* Fixing tests
* Fixing license header
* Fixing writeTo
* Moving file and removing commented code
* Fixing test
* Fixing tests
* Refactoring and tests
* Fixing test
1 parent a230165 commit d9b34d4

38 files changed

+1372
-125
lines changed

server/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@
475475
org.elasticsearch.serverless.apifiltering;
476476
exports org.elasticsearch.lucene.spatial;
477477
exports org.elasticsearch.inference.configuration;
478+
exports org.elasticsearch.inference.validation;
478479
exports org.elasticsearch.monitor.metrics;
479480
exports org.elasticsearch.plugins.internal.rewriter to org.elasticsearch.inference;
480481
exports org.elasticsearch.lucene.util.automaton;

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ static TransportVersion def(int id) {
200200
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52);
201201
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_53);
202202
public static final TransportVersion STREAMS_LOGS_SUPPORT_8_19 = def(8_841_0_54);
203+
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19 = def(8_841_0_55);
203204

204205
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
205206
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
@@ -308,6 +309,7 @@ static TransportVersion def(int id) {
308309
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00);
309310
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00);
310311
public static final TransportVersion STREAMS_LOGS_SUPPORT = def(9_104_0_00);
312+
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE = def(9_105_0_00);
311313

312314
/*
313315
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.client.internal.Client;
1515
import org.elasticsearch.core.Nullable;
1616
import org.elasticsearch.core.TimeValue;
17+
import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
1718

1819
import java.io.Closeable;
1920
import java.util.EnumSet;
@@ -248,4 +249,14 @@ default void updateModelsWithDynamicFields(List<Model> model, ActionListener<Lis
248249
* after ensuring the node's internals are set up (for example if this ensures the internal ES client is ready for use).
249250
*/
250251
default void onNodeStarted() {}
252+
253+
/**
254+
* Get the service integration validator for the given task type.
255+
* This allows services to provide custom validation logic.
256+
* @param taskType The task type
257+
* @return The service integration validator or null if the default should be used
258+
*/
259+
default ServiceIntegrationValidator getServiceIntegrationValidator(TaskType taskType) {
260+
return null;
261+
}
251262
}

server/src/main/java/org/elasticsearch/inference/InputType.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@
1010
package org.elasticsearch.inference;
1111

1212
import org.elasticsearch.common.Strings;
13+
import org.elasticsearch.common.ValidationException;
1314

15+
import java.util.EnumSet;
16+
import java.util.HashMap;
1417
import java.util.Locale;
18+
import java.util.Map;
1519

1620
import static org.elasticsearch.core.Strings.format;
1721

@@ -29,6 +33,13 @@ public enum InputType {
2933
INTERNAL_SEARCH,
3034
INTERNAL_INGEST;
3135

36+
private static final EnumSet<InputType> SUPPORTED_REQUEST_VALUES = EnumSet.of(
37+
InputType.CLASSIFICATION,
38+
InputType.CLUSTERING,
39+
InputType.INGEST,
40+
InputType.SEARCH
41+
);
42+
3243
@Override
3344
public String toString() {
3445
return name().toLowerCase(Locale.ROOT);
@@ -57,4 +68,70 @@ public static boolean isSpecified(InputType inputType) {
5768
public static String invalidInputTypeMessage(InputType inputType) {
5869
return Strings.format("received invalid input type value [%s]", inputType.toString());
5970
}
71+
72+
/**
73+
* Ensures that a map used for translating input types is valid. The keys of the map must be supported {@link InputType}
74+
* request values, and the values are the non-empty strings of the external representation each key is translated to.
75+
* Throws a {@link ValidationException} if any key is not a supported InputType or any value is not a non-empty String.
76+
*
77+
* @param inputTypeTranslation the map of input type translations to validate
78+
* @param validationException a ValidationException to which errors will be added
79+
*/
80+
public static Map<InputType, String> validateInputTypeTranslationValues(
81+
Map<String, Object> inputTypeTranslation,
82+
ValidationException validationException
83+
) {
84+
if (inputTypeTranslation == null || inputTypeTranslation.isEmpty()) {
85+
return Map.of();
86+
}
87+
88+
var translationMap = new HashMap<InputType, String>();
89+
90+
for (var entry : inputTypeTranslation.entrySet()) {
91+
var key = entry.getKey();
92+
var value = entry.getValue();
93+
94+
if (value instanceof String == false || Strings.isNullOrEmpty((String) value)) {
95+
validationException.addValidationError(
96+
Strings.format(
97+
"Input type translation value for key [%s] must be a String that is "
98+
+ "not null and not empty, received: [%s], type: [%s].",
99+
key,
100+
value,
101+
value == null ? "null" : value.getClass().getSimpleName()
102+
)
103+
);
104+
105+
throw validationException;
106+
}
107+
108+
try {
109+
var inputTypeKey = InputType.fromStringValidateSupportedRequestValue(key);
110+
translationMap.put(inputTypeKey, (String) value);
111+
} catch (Exception e) {
112+
validationException.addValidationError(
113+
Strings.format(
114+
"Invalid input type translation for key: [%s], is not a valid value. Must be one of %s",
115+
key,
116+
SUPPORTED_REQUEST_VALUES
117+
)
118+
);
119+
120+
throw validationException;
121+
}
122+
}
123+
124+
return translationMap;
125+
}
126+
127+
private static InputType fromStringValidateSupportedRequestValue(String name) {
128+
var inputType = fromRestString(name);
129+
if (SUPPORTED_REQUEST_VALUES.contains(inputType) == false) {
130+
throw new IllegalArgumentException(
131+
format("Unrecognized input_type [%s], must be one of %s", inputType, SUPPORTED_REQUEST_VALUES)
132+
);
133+
}
134+
135+
return inputType;
136+
}
60137
}
Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
/*
22
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3-
* or more contributor license agreements. Licensed under the Elastic License
4-
* 2.0; you may not use this file except in compliance with the Elastic License
5-
* 2.0.
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
68
*/
79

8-
package org.elasticsearch.xpack.inference.services.validation;
10+
package org.elasticsearch.inference.validation;
911

1012
import org.elasticsearch.action.ActionListener;
1113
import org.elasticsearch.core.TimeValue;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ private void parseAndStoreModel(
216216
if (skipValidationAndStart) {
217217
storeModelListener.onResponse(model);
218218
} else {
219-
ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService)
219+
ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service)
220220
.validate(service, model, timeout, storeModelListener);
221221
}
222222
});

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,18 @@ public static String extractRequiredString(
399399
return requiredField;
400400
}
401401

402+
public static String extractOptionalEmptyString(Map<String, Object> map, String settingName, ValidationException validationException) {
403+
int initialValidationErrorCount = validationException.validationErrors().size();
404+
String optionalField = ServiceUtils.removeAsType(map, settingName, String.class, validationException);
405+
406+
if (validationException.validationErrors().size() > initialValidationErrorCount) {
407+
// new validation error occurred
408+
return null;
409+
}
410+
411+
return optionalField;
412+
}
413+
402414
public static String extractOptionalString(
403415
Map<String, Object> map,
404416
String settingName,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
2424
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
2525
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
26+
import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters;
2627
import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
28+
import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters;
29+
import org.elasticsearch.xpack.inference.services.custom.request.RequestParameters;
30+
import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters;
2731
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseEntity;
2832

29-
import java.util.List;
3033
import java.util.Objects;
3134
import java.util.function.Supplier;
3235

@@ -65,19 +68,16 @@ public void execute(
6568
Supplier<Boolean> hasRequestCompletedFunction,
6669
ActionListener<InferenceServiceResults> listener
6770
) {
68-
String query;
69-
List<String> input;
71+
RequestParameters requestParameters;
7072
if (inferenceInputs instanceof QueryAndDocsInputs) {
71-
QueryAndDocsInputs queryAndDocsInputs = QueryAndDocsInputs.of(inferenceInputs);
72-
query = queryAndDocsInputs.getQuery();
73-
input = queryAndDocsInputs.getChunks();
73+
requestParameters = RerankParameters.of(QueryAndDocsInputs.of(inferenceInputs));
7474
} else if (inferenceInputs instanceof ChatCompletionInput chatInputs) {
75-
query = null;
76-
input = chatInputs.getInputs();
75+
requestParameters = CompletionParameters.of(chatInputs);
7776
} else if (inferenceInputs instanceof EmbeddingsInput) {
78-
EmbeddingsInput embeddingsInput = EmbeddingsInput.of(inferenceInputs);
79-
query = null;
80-
input = embeddingsInput.getStringInputs();
77+
requestParameters = EmbeddingParameters.of(
78+
EmbeddingsInput.of(inferenceInputs),
79+
model.getServiceSettings().getInputTypeTranslator()
80+
);
8181
} else {
8282
listener.onFailure(
8383
new ElasticsearchStatusException(
@@ -89,7 +89,7 @@ public void execute(
8989
}
9090

9191
try {
92-
var request = new CustomRequest(query, input, model);
92+
var request = new CustomRequest(requestParameters, model);
9393
execute(new ExecutableInferenceRequest(requestSender, logger, request, handler, hasRequestCompletedFunction, listener));
9494
} catch (Exception e) {
9595
// Intentionally not logging this exception because it could contain sensitive information from the CustomRequest construction

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,26 @@
2727
import org.elasticsearch.inference.SettingsConfiguration;
2828
import org.elasticsearch.inference.SimilarityMeasure;
2929
import org.elasticsearch.inference.TaskType;
30+
import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
3031
import org.elasticsearch.rest.RestStatus;
3132
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3233
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
3334
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
35+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
3436
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
3537
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3638
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
39+
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
3740
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
3841
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
3942
import org.elasticsearch.xpack.inference.services.SenderService;
4043
import org.elasticsearch.xpack.inference.services.ServiceComponents;
41-
import org.elasticsearch.xpack.inference.services.ServiceUtils;
44+
import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters;
4245
import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
46+
import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters;
47+
import org.elasticsearch.xpack.inference.services.custom.request.RequestParameters;
48+
import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters;
49+
import org.elasticsearch.xpack.inference.services.validation.CustomServiceIntegrationValidator;
4350

4451
import java.util.EnumSet;
4552
import java.util.HashMap;
@@ -115,20 +122,29 @@ public void parseRequestConfig(
115122
* This does some initial validation with mock inputs to determine if any templates are missing a field to fill them.
116123
*/
117124
private static void validateConfiguration(CustomModel model) {
118-
String query = null;
119-
if (model.getTaskType() == TaskType.RERANK) {
120-
query = "test query";
121-
}
122-
123125
try {
124-
new CustomRequest(query, List.of("test input"), model).createHttpRequest();
126+
new CustomRequest(createParameters(model), model).createHttpRequest();
125127
} catch (IllegalStateException e) {
126128
var validationException = new ValidationException();
127129
validationException.addValidationError(Strings.format("Failed to validate model configuration: %s", e.getMessage()));
128130
throw validationException;
129131
}
130132
}
131133

134+
private static RequestParameters createParameters(CustomModel model) {
135+
return switch (model.getTaskType()) {
136+
case RERANK -> RerankParameters.of(new QueryAndDocsInputs("test query", List.of("test input")));
137+
case COMPLETION -> CompletionParameters.of(new ChatCompletionInput(List.of("test input")));
138+
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> EmbeddingParameters.of(
139+
new EmbeddingsInput(List.of("test input"), null, null),
140+
model.getServiceSettings().getInputTypeTranslator()
141+
);
142+
default -> throw new IllegalStateException(
143+
Strings.format("Unsupported task type [%s] for custom service", model.getTaskType())
144+
);
145+
};
146+
}
147+
132148
private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
133149
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
134150
return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
@@ -257,7 +273,8 @@ public void doInfer(
257273

258274
@Override
259275
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
260-
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException);
276+
// The custom service doesn't do any validation for the input type because, if the input type is supported, a default
277+
// must be supplied within the service settings.
261278
}
262279

263280
@Override
@@ -327,7 +344,9 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom
327344
serviceSettings.getQueryParameters(),
328345
serviceSettings.getRequestContentString(),
329346
serviceSettings.getResponseJsonParser(),
330-
serviceSettings.rateLimitSettings()
347+
serviceSettings.rateLimitSettings(),
348+
serviceSettings.getBatchSize(),
349+
serviceSettings.getInputTypeTranslator()
331350
);
332351
}
333352

@@ -353,4 +372,13 @@ public static InferenceServiceConfiguration get() {
353372
}
354373
);
355374
}
375+
376+
@Override
377+
public ServiceIntegrationValidator getServiceIntegrationValidator(TaskType taskType) {
378+
if (taskType == TaskType.RERANK) {
379+
return new CustomServiceIntegrationValidator();
380+
}
381+
382+
return null;
383+
}
356384
}

0 commit comments

Comments
 (0)