55
66package org .opensearch .ml .task ;
77
8+ import static org .opensearch .common .xcontent .XContentParserUtils .ensureExpectedToken ;
89import static org .opensearch .ml .indices .MLIndicesHandler .ML_MODEL_INDEX ;
910import static org .opensearch .ml .permission .AccessController .checkUserPermissions ;
1011import static org .opensearch .ml .permission .AccessController .getUserContext ;
1718
1819import java .time .Instant ;
1920import java .util .Base64 ;
20- import java .util .Map ;
2121import java .util .UUID ;
2222
2323import lombok .extern .log4j .Log4j2 ;
3232import org .opensearch .client .Client ;
3333import org .opensearch .cluster .service .ClusterService ;
3434import org .opensearch .common .util .concurrent .ThreadContext ;
35+ import org .opensearch .common .xcontent .LoggingDeprecationHandler ;
36+ import org .opensearch .common .xcontent .NamedXContentRegistry ;
37+ import org .opensearch .common .xcontent .XContentParser ;
38+ import org .opensearch .common .xcontent .XContentType ;
3539import org .opensearch .commons .authuser .User ;
3640import org .opensearch .ml .common .MLModel ;
3741import org .opensearch .ml .common .MLTask ;
5357import org .opensearch .ml .stats .ActionName ;
5458import org .opensearch .ml .stats .MLStats ;
5559import org .opensearch .threadpool .ThreadPool ;
56- import org .opensearch .transport .TransportService ;
60+ import org .opensearch .transport .TransportResponseHandler ;
5761
5862/**
5963 * MLPredictTaskRunner is responsible for running predict tasks.
@@ -64,6 +68,7 @@ public class MLPredictTaskRunner extends MLTaskRunner<MLPredictionTaskRequest, M
6468 private final ClusterService clusterService ;
6569 private final Client client ;
6670 private final MLInputDatasetHandler mlInputDatasetHandler ;
71+ private final NamedXContentRegistry xContentRegistry ;
6772
6873 public MLPredictTaskRunner (
6974 ThreadPool threadPool ,
@@ -73,42 +78,34 @@ public MLPredictTaskRunner(
7378 MLStats mlStats ,
7479 MLInputDatasetHandler mlInputDatasetHandler ,
7580 MLTaskDispatcher mlTaskDispatcher ,
76- MLCircuitBreakerService mlCircuitBreakerService
81+ MLCircuitBreakerService mlCircuitBreakerService ,
82+ NamedXContentRegistry xContentRegistry
7783 ) {
78- super (mlTaskManager , mlStats , mlTaskDispatcher , mlCircuitBreakerService );
84+ super (mlTaskManager , mlStats , mlTaskDispatcher , mlCircuitBreakerService , clusterService );
7985 this .threadPool = threadPool ;
8086 this .clusterService = clusterService ;
8187 this .client = client ;
8288 this .mlInputDatasetHandler = mlInputDatasetHandler ;
89+ this .xContentRegistry = xContentRegistry ;
8390 }
8491
8592 @ Override
86- public void executeTask (MLPredictionTaskRequest request , TransportService transportService , ActionListener <MLTaskResponse > listener ) {
87- mlTaskDispatcher .dispatchTask (ActionListener .wrap (node -> {
88- if (clusterService .localNode ().getId ().equals (node .getId ())) {
89- // Execute prediction task locally
90- log .info ("execute ML prediction request {} locally on node {}" , request .toString (), node .getId ());
91- startPredictionTask (request , listener );
92- } else {
93- // Execute batch task remotely
94- log .info ("execute ML prediction request {} remotely on node {}" , request .toString (), node .getId ());
95- transportService
96- .sendRequest (
97- node ,
98- MLPredictionTaskAction .NAME ,
99- request ,
100- new ActionListenerResponseHandler <>(listener , MLTaskResponse ::new )
101- );
102- }
103- }, e -> listener .onFailure (e )));
93+ protected String getTransportActionName () {
94+ return MLPredictionTaskAction .NAME ;
95+ }
96+
97+ @ Override
98+ protected TransportResponseHandler <MLTaskResponse > getResponseHandler (ActionListener <MLTaskResponse > listener ) {
99+ return new ActionListenerResponseHandler <>(listener , MLTaskResponse ::new );
104100 }
105101
106102 /**
107103 * Start prediction task
108104 * @param request MLPredictionTaskRequest
109105 * @param listener Action listener
110106 */
111- public void startPredictionTask (MLPredictionTaskRequest request , ActionListener <MLTaskResponse > listener ) {
107+ @ Override
108+ protected void executeTask (MLPredictionTaskRequest request , ActionListener <MLTaskResponse > listener ) {
112109 MLInputDataType inputDataType = request .getMlInput ().getInputDataset ().getInputDataType ();
113110 Instant now = Instant .now ();
114111 MLTask mlTask = MLTask
@@ -166,36 +163,49 @@ private void predict(
166163 internalListener .onFailure (new ResourceNotFoundException ("No model found, please check the modelId." ));
167164 return ;
168165 }
169- Map <String , Object > source = r .getSourceAsMap ();
170- User requestUser = getUserContext (client );
171- User resourceUser = User .parse ((String ) source .get (USER ));
172- if (!checkUserPermissions (requestUser , resourceUser , request .getModelId ())) {
173- // The backend roles of request user and resource user doesn't have intersection
174- OpenSearchException e = new OpenSearchException (
175- "User: " + requestUser .getName () + " does not have permissions to run predict by model: " + request .getModelId ()
176- );
177- handlePredictFailure (mlTask , internalListener , e , false );
178- return ;
179- }
166+ try (
167+ XContentParser xContentParser = XContentType .JSON
168+ .xContent ()
169+ .createParser (xContentRegistry , LoggingDeprecationHandler .INSTANCE , r .getSourceAsString ())
170+ ) {
171+ ensureExpectedToken (XContentParser .Token .START_OBJECT , xContentParser .nextToken (), xContentParser );
172+ MLModel mlModel = MLModel .parse (xContentParser );
173+ User resourceUser = mlModel .getUser ();
174+ User requestUser = getUserContext (client );
175+ if (!checkUserPermissions (requestUser , resourceUser , request .getModelId ())) {
176+ // The backend roles of request user and resource user doesn't have intersection
177+ OpenSearchException e = new OpenSearchException (
178+ "User: "
179+ + requestUser .getName ()
180+ + " does not have permissions to run predict by model: "
181+ + request .getModelId ()
182+ );
183+ handlePredictFailure (mlTask , internalListener , e , false );
184+ return ;
185+ }
186+ Model model = new Model ();
187+ model .setName (mlModel .getName ());
188+ model .setVersion (mlModel .getVersion ());
189+ byte [] decoded = Base64 .getDecoder ().decode (mlModel .getContent ());
190+ model .setContent (decoded );
191+
192+ // run predict
193+ mlTaskManager .updateTaskState (mlTask .getTaskId (), MLTaskState .RUNNING , mlTask .isAsync ());
194+ MLOutput output = MLEngine
195+ .predict (mlInput .toBuilder ().inputDataset (new DataFrameInputDataset (inputDataFrame )).build (), model );
196+ if (output instanceof MLPredictionOutput ) {
197+ ((MLPredictionOutput ) output ).setStatus (MLTaskState .COMPLETED .name ());
198+ }
180199
181- Model model = new Model ();
182- model .setName ((String ) source .get (MLModel .MODEL_NAME ));
183- model .setVersion ((Integer ) source .get (MLModel .MODEL_VERSION ));
184- byte [] decoded = Base64 .getDecoder ().decode ((String ) source .get (MLModel .MODEL_CONTENT ));
185- model .setContent (decoded );
186-
187- // run predict
188- mlTaskManager .updateTaskState (mlTask .getTaskId (), MLTaskState .RUNNING , mlTask .isAsync ());
189- MLOutput output = MLEngine
190- .predict (mlInput .toBuilder ().inputDataset (new DataFrameInputDataset (inputDataFrame )).build (), model );
191- if (output instanceof MLPredictionOutput ) {
192- ((MLPredictionOutput ) output ).setStatus (MLTaskState .COMPLETED .name ());
200+ // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
201+ handleAsyncMLTaskComplete (mlTask );
202+ MLTaskResponse response = MLTaskResponse .builder ().output (output ).build ();
203+ internalListener .onResponse (response );
204+ } catch (Exception e ) {
205+ log .error ("Failed to predict model " + request .getModelId (), e );
206+ internalListener .onFailure (e );
193207 }
194208
195- // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
196- handleAsyncMLTaskComplete (mlTask );
197- MLTaskResponse response = MLTaskResponse .builder ().output (output ).build ();
198- internalListener .onResponse (response );
199209 }, e -> {
200210 log .error ("Failed to predict " + mlInput .getAlgorithm () + ", modelId: " + mlTask .getModelId (), e );
201211 handlePredictFailure (mlTask , internalListener , e , true );
0 commit comments