3434import org .elasticsearch .xpack .inference .telemetry .InferenceStats ;
3535import org .elasticsearch .xpack .inference .telemetry .InferenceTimer ;
3636
37+ import java .util .function .Supplier ;
3738import java .util .stream .Collectors ;
3839
3940import static org .elasticsearch .core .Strings .format ;
@@ -75,27 +76,43 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
7576
7677 var getModelListener = ActionListener .wrap ((UnparsedModel unparsedModel ) -> {
7778 var service = serviceRegistry .getService (unparsedModel .service ());
78- if (service .isEmpty ()) {
79- var e = unknownServiceException (unparsedModel .service (), request .getInferenceEntityId ());
79+ try {
80+ validationHelper (service ::isEmpty , () -> unknownServiceException (unparsedModel .service (), request .getInferenceEntityId ()));
81+ validationHelper (
82+ () -> request .getTaskType ().isAnyOrSame (unparsedModel .taskType ()) == false ,
83+ () -> requestModelTaskTypeMismatchException (request .getTaskType (), unparsedModel .taskType ())
84+ );
85+ validationHelper (
86+ () -> isInvalidTaskTypeForInferenceEndpoint (request , unparsedModel ),
87+ () -> createInvalidTaskTypeException (request , unparsedModel )
88+ );
89+ } catch (Exception e ) {
8090 recordMetrics (unparsedModel , timer , e );
8191 listener .onFailure (e );
8292 return ;
8393 }
8494
85- if (request .getTaskType ().isAnyOrSame (unparsedModel .taskType ()) == false ) {
86- // not the wildcard task type and not the model task type
87- var e = incompatibleTaskTypeException (request .getTaskType (), unparsedModel .taskType ());
88- recordMetrics (unparsedModel , timer , e );
89- listener .onFailure (e );
90- return ;
91- }
92-
93- if (isInvalidTaskTypeForInferenceEndpoint (request , unparsedModel )) {
94- var e = createIncompatibleTaskTypeException (request , unparsedModel );
95- recordMetrics (unparsedModel , timer , e );
96- listener .onFailure (e );
97- return ;
98- }
95+ // if (service.isEmpty()) {
96+ // var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
97+ // recordMetrics(unparsedModel, timer, e);
98+ // listener.onFailure(e);
99+ // return;
100+ // }
101+
102+ // if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
103+ // // not the wildcard task type and not the model task type
104+ // var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
105+ // recordMetrics(unparsedModel, timer, e);
106+ // listener.onFailure(e);
107+ // return;
108+ // }
109+
110+ // if (isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel)) {
111+ // var e = createInvalidTaskTypeException(request, unparsedModel);
112+ // recordMetrics(unparsedModel, timer, e);
113+ // listener.onFailure(e);
114+ // return;
115+ // }
99116
100117 var model = service .get ()
101118 .parsePersistedConfigWithSecrets (
@@ -117,9 +134,15 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
117134 modelRegistry .getModelWithSecrets (request .getInferenceEntityId (), getModelListener );
118135 }
119136
137+ private static void validationHelper (Supplier <Boolean > validationFailure , Supplier <ElasticsearchStatusException > exceptionCreator ) {
138+ if (validationFailure .get ()) {
139+ throw exceptionCreator .get ();
140+ }
141+ }
142+
120143 protected abstract boolean isInvalidTaskTypeForInferenceEndpoint (Request request , UnparsedModel unparsedModel );
121144
122- protected abstract ElasticsearchStatusException createIncompatibleTaskTypeException (Request request , UnparsedModel unparsedModel );
145+ protected abstract ElasticsearchStatusException createInvalidTaskTypeException (Request request , UnparsedModel unparsedModel );
123146
124147 private void recordMetrics (UnparsedModel model , InferenceTimer timer , @ Nullable Throwable t ) {
125148 try {
@@ -225,7 +248,7 @@ private static ElasticsearchStatusException unknownServiceException(String servi
225248 return new ElasticsearchStatusException ("Unknown service [{}] for model [{}]. " , RestStatus .BAD_REQUEST , service , inferenceId );
226249 }
227250
228- private static ElasticsearchStatusException incompatibleTaskTypeException (TaskType requested , TaskType expected ) {
251+ private static ElasticsearchStatusException requestModelTaskTypeMismatchException (TaskType requested , TaskType expected ) {
229252 return new ElasticsearchStatusException (
230253 "Incompatible task_type, the requested type [{}] does not match the model type [{}]" ,
231254 RestStatus .BAD_REQUEST ,
0 commit comments