Skip to content

Commit 41f9bce

Browse files
Refactoring transport action tests to test unified validation code
1 parent 99d202f commit 41f9bce

File tree

6 files changed

+504
-346
lines changed

6 files changed

+504
-346
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
3535
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
3636

37+
import java.util.function.Supplier;
3738
import java.util.stream.Collectors;
3839

3940
import static org.elasticsearch.core.Strings.format;
@@ -75,27 +76,43 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
7576

7677
var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
7778
var service = serviceRegistry.getService(unparsedModel.service());
78-
if (service.isEmpty()) {
79-
var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
79+
try {
80+
validationHelper(service::isEmpty, () -> unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
81+
validationHelper(
82+
() -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false,
83+
() -> requestModelTaskTypeMismatchException(request.getTaskType(), unparsedModel.taskType())
84+
);
85+
validationHelper(
86+
() -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel),
87+
() -> createInvalidTaskTypeException(request, unparsedModel)
88+
);
89+
} catch (Exception e) {
8090
recordMetrics(unparsedModel, timer, e);
8191
listener.onFailure(e);
8292
return;
8393
}
8494

85-
if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
86-
// not the wildcard task type and not the model task type
87-
var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
88-
recordMetrics(unparsedModel, timer, e);
89-
listener.onFailure(e);
90-
return;
91-
}
92-
93-
if (isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel)) {
94-
var e = createIncompatibleTaskTypeException(request, unparsedModel);
95-
recordMetrics(unparsedModel, timer, e);
96-
listener.onFailure(e);
97-
return;
98-
}
95+
// if (service.isEmpty()) {
96+
// var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
97+
// recordMetrics(unparsedModel, timer, e);
98+
// listener.onFailure(e);
99+
// return;
100+
// }
101+
102+
// if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
103+
// // not the wildcard task type and not the model task type
104+
// var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
105+
// recordMetrics(unparsedModel, timer, e);
106+
// listener.onFailure(e);
107+
// return;
108+
// }
109+
110+
// if (isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel)) {
111+
// var e = createInvalidTaskTypeException(request, unparsedModel);
112+
// recordMetrics(unparsedModel, timer, e);
113+
// listener.onFailure(e);
114+
// return;
115+
// }
99116

100117
var model = service.get()
101118
.parsePersistedConfigWithSecrets(
@@ -117,9 +134,15 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
117134
modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
118135
}
119136

137+
private static void validationHelper(Supplier<Boolean> validationFailure, Supplier<ElasticsearchStatusException> exceptionCreator) {
138+
if (validationFailure.get()) {
139+
throw exceptionCreator.get();
140+
}
141+
}
142+
120143
protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel);
121144

122-
protected abstract ElasticsearchStatusException createIncompatibleTaskTypeException(Request request, UnparsedModel unparsedModel);
145+
protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, UnparsedModel unparsedModel);
123146

124147
private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
125148
try {
@@ -225,7 +248,7 @@ private static ElasticsearchStatusException unknownServiceException(String servi
225248
return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
226249
}
227250

228-
private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) {
251+
private static ElasticsearchStatusException requestModelTaskTypeMismatchException(TaskType requested, TaskType expected) {
229252
return new ElasticsearchStatusException(
230253
"Incompatible task_type, the requested type [{}] does not match the model type [{}]",
231254
RestStatus.BAD_REQUEST,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@ protected boolean isInvalidTaskTypeForInferenceEndpoint(InferenceAction.Request
5151
}
5252

5353
@Override
54-
protected ElasticsearchStatusException createIncompatibleTaskTypeException(
55-
InferenceAction.Request request,
56-
UnparsedModel unparsedModel
57-
) {
54+
protected ElasticsearchStatusException createInvalidTaskTypeException(InferenceAction.Request request, UnparsedModel unparsedModel) {
5855
return null;
5956
}
6057

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.
5353
}
5454

5555
@Override
56-
protected ElasticsearchStatusException createIncompatibleTaskTypeException(
56+
protected ElasticsearchStatusException createInvalidTaskTypeException(
5757
UnifiedCompletionAction.Request request,
5858
UnparsedModel unparsedModel
5959
) {

0 commit comments

Comments
 (0)