Skip to content

Commit ac4da6a

Browse files
Adding template validation prior to request flow
1 parent d0b74df commit ac4da6a

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutor.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@ private static void ensureNoMorePlaceholdersExist(String substitutedString, Stri
5151
Matcher matcher = VARIABLE_PLACEHOLDER_PATTERN.matcher(substitutedString);
5252
if (matcher.find()) {
5353
throw new IllegalStateException(
54-
Strings.format("Found placeholder [%s] in field [%s] after replacement call", matcher.group(), field)
54+
Strings.format(
55+
"Found placeholder [%s] in field [%s] after replacement call, "
56+
+ "please check that all templates have a corresponding field definition.",
57+
matcher.group(),
58+
field
59+
)
5560
);
5661
}
5762
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.elasticsearch.xpack.inference.services.SenderService;
3737
import org.elasticsearch.xpack.inference.services.ServiceComponents;
3838
import org.elasticsearch.xpack.inference.services.ServiceUtils;
39+
import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
3940

4041
import java.util.EnumSet;
4142
import java.util.HashMap;
@@ -94,12 +95,32 @@ public void parseRequestConfig(
9495
throwIfNotEmptyMap(serviceSettingsMap, NAME);
9596
throwIfNotEmptyMap(taskSettingsMap, NAME);
9697

98+
validateConfiguration(model);
99+
97100
parsedModelListener.onResponse(model);
98101
} catch (Exception e) {
99102
parsedModelListener.onFailure(e);
100103
}
101104
}
102105

106+
/**
107+
* This does some initial validation with mock inputs to determine if any templates are missing a field to fill them.
108+
*/
109+
private static void validateConfiguration(CustomModel model) {
110+
String query = null;
111+
if (model.getTaskType() == TaskType.RERANK) {
112+
query = "test query";
113+
}
114+
115+
try {
116+
new CustomRequest(query, List.of("test input"), model).createHttpRequest();
117+
} catch (Exception e) {
118+
var validationException = new ValidationException();
119+
validationException.addValidationError(Strings.format("Failed to validate model configuration: %s", e.getMessage()));
120+
throw validationException;
121+
}
122+
}
123+
103124
@Override
104125
public InferenceServiceConfiguration getConfiguration() {
105126
return Configuration.get();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.custom;
99

1010
import org.elasticsearch.action.support.PlainActionFuture;
11+
import org.elasticsearch.common.ValidationException;
1112
import org.elasticsearch.common.settings.SecureString;
1213
import org.elasticsearch.core.Nullable;
1314
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
@@ -44,6 +45,7 @@
4445
import java.util.Map;
4546

4647
import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
48+
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
4749
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
4850
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
4951
import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_DOCUMENT_TEXT;
@@ -546,4 +548,47 @@ public void testInfer_HandlesSparseEmbeddingRequest_Alibaba_Format() throws IOEx
546548
);
547549
}
548550
}
551+
552+
public void testParseRequestConfig_ThrowsAValidationError_WhenReplacementDoesNotFillTemplate() throws Exception {
553+
try (var service = createService(threadPool, clientManager)) {
554+
555+
var settingsMap = new HashMap<>(
556+
Map.of(
557+
CustomServiceSettings.URL,
558+
"http://www.abc.com",
559+
CustomServiceSettings.HEADERS,
560+
Map.of("key", "value"),
561+
QueryParameters.QUERY_PARAMETERS,
562+
List.of(List.of("key", "value")),
563+
CustomServiceSettings.REQUEST,
564+
"request body ${some_template}",
565+
CustomServiceSettings.RESPONSE,
566+
new HashMap<>(
567+
Map.of(
568+
CustomServiceSettings.JSON_PARSER,
569+
createResponseParserMap(TaskType.COMPLETION),
570+
CustomServiceSettings.ERROR_PARSER,
571+
new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message"))
572+
)
573+
)
574+
)
575+
);
576+
577+
var config = getRequestConfigMap(settingsMap, createTaskSettingsMap(), createSecretSettingsMap());
578+
579+
var listener = new PlainActionFuture<Model>();
580+
service.parseRequestConfig("id", TaskType.COMPLETION, config, listener);
581+
582+
var exception = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT));
583+
584+
assertThat(
585+
exception.getMessage(),
586+
is(
587+
"Validation Failed: 1: Failed to validate model configuration: Found placeholder "
588+
+ "[${some_template}] in field [request] after replacement call, please check that all "
589+
+ "templates have a corresponding field definition.;"
590+
)
591+
);
592+
}
593+
}
549594
}

0 commit comments

Comments
 (0)