4545import org .opensearch .common .settings .Setting ;
4646import org .opensearch .common .settings .Settings ;
4747import org .opensearch .common .util .concurrent .ThreadContext ;
48+ import org .opensearch .commons .authuser .User ;
4849import org .opensearch .core .action .ActionListener ;
4950import org .opensearch .core .rest .RestStatus ;
5051import org .opensearch .core .xcontent .NamedXContentRegistry ;
5960import org .opensearch .ml .common .connector .ConnectorAction .ActionType ;
6061import org .opensearch .ml .common .dataset .remote .RemoteInferenceInputDataSet ;
6162import org .opensearch .ml .common .exception .MLResourceNotFoundException ;
63+ import org .opensearch .ml .common .exception .MLValidationException ;
6264import org .opensearch .ml .common .input .MLInput ;
6365import org .opensearch .ml .common .output .model .ModelTensorOutput ;
6466import org .opensearch .ml .common .output .model .ModelTensors ;
7173import org .opensearch .ml .engine .algorithms .remote .RemoteConnectorExecutor ;
7274import org .opensearch .ml .engine .encryptor .EncryptorImpl ;
7375import org .opensearch .ml .helper .ConnectorAccessControlHelper ;
76+ import org .opensearch .ml .helper .ModelAccessControlHelper ;
7477import org .opensearch .ml .model .MLModelManager ;
7578import org .opensearch .ml .settings .MLFeatureEnabledSetting ;
7679import org .opensearch .ml .task .MLTaskManager ;
80+ import org .opensearch .ml .utils .RestActionUtils ;
7781import org .opensearch .script .ScriptService ;
7882import org .opensearch .tasks .Task ;
7983import org .opensearch .transport .TransportService ;
@@ -90,6 +94,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
9094 ScriptService scriptService ;
9195
9296 ConnectorAccessControlHelper connectorAccessControlHelper ;
97+ ModelAccessControlHelper modelAccessControlHelper ;
9398 EncryptorImpl encryptor ;
9499 MLModelManager mlModelManager ;
95100
@@ -111,6 +116,7 @@ public GetTaskTransportAction(
111116 ClusterService clusterService ,
112117 ScriptService scriptService ,
113118 ConnectorAccessControlHelper connectorAccessControlHelper ,
119+ ModelAccessControlHelper modelAccessControlHelper ,
114120 EncryptorImpl encryptor ,
115121 MLTaskManager mlTaskManager ,
116122 MLModelManager mlModelManager ,
@@ -123,6 +129,7 @@ public GetTaskTransportAction(
123129 this .clusterService = clusterService ;
124130 this .scriptService = scriptService ;
125131 this .connectorAccessControlHelper = connectorAccessControlHelper ;
132+ this .modelAccessControlHelper = modelAccessControlHelper ;
126133 this .encryptor = encryptor ;
127134 this .mlTaskManager = mlTaskManager ;
128135 this .mlModelManager = mlModelManager ;
@@ -238,26 +245,40 @@ private void processRemoteBatchPrediction(MLTask mlTask, String taskId, ActionLi
238245 RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet (parameters , ActionType .BATCH_PREDICT_STATUS );
239246 MLInput mlInput = MLInput .builder ().algorithm (FunctionName .REMOTE ).inputDataset (inferenceInputDataSet ).build ();
240247 String modelId = mlTask .getModelId ();
248+ User user = RestActionUtils .getUserContext (client );
241249
242250 try (ThreadContext .StoredContext context = client .threadPool ().getThreadContext ().stashContext ()) {
243251 ActionListener <MLModel > getModelListener = ActionListener .wrap (model -> {
244- if (model .getConnector () != null ) {
245- Connector connector = model .getConnector ();
246- executeConnector (connector , mlInput , taskId , mlTask , remoteJob , actionListener );
247- } else if (clusterService .state ().metadata ().hasIndex (ML_CONNECTOR_INDEX )) {
248- ActionListener <Connector > listener = ActionListener .wrap (connector -> {
249- executeConnector (connector , mlInput , taskId , mlTask , remoteJob , actionListener );
250- }, e -> {
251- log .error ("Failed to get connector " + model .getConnectorId (), e );
252- actionListener .onFailure (e );
253- });
254- try (ThreadContext .StoredContext threadContext = client .threadPool ().getThreadContext ().stashContext ()) {
255- connectorAccessControlHelper
256- .getConnector (client , model .getConnectorId (), ActionListener .runBefore (listener , threadContext ::restore ));
252+ modelAccessControlHelper .validateModelGroupAccess (user , model .getModelGroupId (), client , ActionListener .wrap (access -> {
253+ if (!access ) {
254+ actionListener .onFailure (new MLValidationException ("You don't have permission to access this batch job" ));
255+ } else {
256+ if (model .getConnector () != null ) {
257+ Connector connector = model .getConnector ();
258+ executeConnector (connector , mlInput , taskId , mlTask , remoteJob , actionListener );
259+ } else if (clusterService .state ().metadata ().hasIndex (ML_CONNECTOR_INDEX )) {
260+ ActionListener <Connector > listener = ActionListener .wrap (connector -> {
261+ executeConnector (connector , mlInput , taskId , mlTask , remoteJob , actionListener );
262+ }, e -> {
263+ log .error ("Failed to get connector " + model .getConnectorId (), e );
264+ actionListener .onFailure (e );
265+ });
266+ try (ThreadContext .StoredContext threadContext = client .threadPool ().getThreadContext ().stashContext ()) {
267+ connectorAccessControlHelper
268+ .getConnector (
269+ client ,
270+ model .getConnectorId (),
271+ ActionListener .runBefore (listener , threadContext ::restore )
272+ );
273+ }
274+ } else {
275+ actionListener .onFailure (new ResourceNotFoundException ("Can't find connector " + model .getConnectorId ()));
276+ }
257277 }
258- } else {
259- actionListener .onFailure (new ResourceNotFoundException ("Can't find connector " + model .getConnectorId ()));
260- }
278+ }, e -> {
279+ log .error ("Failed to validate Access for Model Group " + model .getModelGroupId (), e );
280+ actionListener .onFailure (e );
281+ }));
261282 }, e -> {
262283 log .error ("Failed to retrieve the ML model for the given task ID" , e );
263284 actionListener
@@ -280,26 +301,20 @@ private void executeConnector(
280301 Map <String , Object > remoteJob ,
281302 ActionListener <MLTaskGetResponse > actionListener
282303 ) {
283- if (connectorAccessControlHelper .validateConnectorAccess (client , connector )) {
284- Optional <ConnectorAction > batchPredictStatusAction = connector .findAction (BATCH_PREDICT_STATUS .name ());
285- if (!batchPredictStatusAction .isPresent () || batchPredictStatusAction .get ().getRequestBody () == null ) {
286- ConnectorAction connectorAction = ConnectorUtils .createConnectorAction (connector , BATCH_PREDICT_STATUS );
287- connector .addAction (connectorAction );
288- }
289- connector .decrypt (BATCH_PREDICT_STATUS .name (), (credential ) -> encryptor .decrypt (credential ));
290- RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
291- .initInstance (connector .getProtocol (), connector , Connector .class );
292- connectorExecutor .setScriptService (scriptService );
293- connectorExecutor .setClusterService (clusterService );
294- connectorExecutor .setClient (client );
295- connectorExecutor .setXContentRegistry (xContentRegistry );
296- connectorExecutor .executeAction (BATCH_PREDICT_STATUS .name (), mlInput , ActionListener .wrap (taskResponse -> {
297- processTaskResponse (mlTask , taskId , taskResponse , remoteJob , actionListener );
298- }, e -> { actionListener .onFailure (e ); }));
299- } else {
300- actionListener
301- .onFailure (new OpenSearchStatusException ("You don't have permission to access this connector" , RestStatus .FORBIDDEN ));
304+ Optional <ConnectorAction > batchPredictStatusAction = connector .findAction (BATCH_PREDICT_STATUS .name ());
305+ if (!batchPredictStatusAction .isPresent () || batchPredictStatusAction .get ().getRequestBody () == null ) {
306+ ConnectorAction connectorAction = ConnectorUtils .createConnectorAction (connector , BATCH_PREDICT_STATUS );
307+ connector .addAction (connectorAction );
302308 }
309+ connector .decrypt (BATCH_PREDICT_STATUS .name (), (credential ) -> encryptor .decrypt (credential ));
310+ RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader .initInstance (connector .getProtocol (), connector , Connector .class );
311+ connectorExecutor .setScriptService (scriptService );
312+ connectorExecutor .setClusterService (clusterService );
313+ connectorExecutor .setClient (client );
314+ connectorExecutor .setXContentRegistry (xContentRegistry );
315+ connectorExecutor .executeAction (BATCH_PREDICT_STATUS .name (), mlInput , ActionListener .wrap (taskResponse -> {
316+ processTaskResponse (mlTask , taskId , taskResponse , remoteJob , actionListener );
317+ }, e -> { actionListener .onFailure (e ); }));
303318 }
304319
305320 protected void processTaskResponse (
0 commit comments