Skip to content

Commit c568d72

Browse files
change to model group access for batch job task APIs (#3098) (#3102)
* change to model group access for batch job task APIs Signed-off-by: Bhavana Ramaram <[email protected]> (cherry picked from commit 6277410) Co-authored-by: Bhavana Ramaram <[email protected]>
1 parent c809580 commit c568d72

File tree

4 files changed

+139
-80
lines changed

4 files changed

+139
-80
lines changed

plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.opensearch.cluster.service.ClusterService;
2929
import org.opensearch.common.inject.Inject;
3030
import org.opensearch.common.util.concurrent.ThreadContext;
31+
import org.opensearch.commons.authuser.User;
3132
import org.opensearch.core.action.ActionListener;
3233
import org.opensearch.core.rest.RestStatus;
3334
import org.opensearch.core.xcontent.NamedXContentRegistry;
@@ -42,6 +43,7 @@
4243
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
4344
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
4445
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
46+
import org.opensearch.ml.common.exception.MLValidationException;
4547
import org.opensearch.ml.common.input.MLInput;
4648
import org.opensearch.ml.common.output.model.ModelTensorOutput;
4749
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -54,9 +56,11 @@
5456
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
5557
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
5658
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
59+
import org.opensearch.ml.helper.ModelAccessControlHelper;
5760
import org.opensearch.ml.model.MLModelManager;
5861
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
5962
import org.opensearch.ml.task.MLTaskManager;
63+
import org.opensearch.ml.utils.RestActionUtils;
6064
import org.opensearch.script.ScriptService;
6165
import org.opensearch.tasks.Task;
6266
import org.opensearch.transport.TransportService;
@@ -73,6 +77,7 @@ public class CancelBatchJobTransportAction extends HandledTransportAction<Action
7377
ScriptService scriptService;
7478

7579
ConnectorAccessControlHelper connectorAccessControlHelper;
80+
ModelAccessControlHelper modelAccessControlHelper;
7681
EncryptorImpl encryptor;
7782
MLModelManager mlModelManager;
7883

@@ -88,6 +93,7 @@ public CancelBatchJobTransportAction(
8893
ClusterService clusterService,
8994
ScriptService scriptService,
9095
ConnectorAccessControlHelper connectorAccessControlHelper,
96+
ModelAccessControlHelper modelAccessControlHelper,
9197
EncryptorImpl encryptor,
9298
MLTaskManager mlTaskManager,
9399
MLModelManager mlModelManager,
@@ -99,6 +105,7 @@ public CancelBatchJobTransportAction(
99105
this.clusterService = clusterService;
100106
this.scriptService = scriptService;
101107
this.connectorAccessControlHelper = connectorAccessControlHelper;
108+
this.modelAccessControlHelper = modelAccessControlHelper;
102109
this.encryptor = encryptor;
103110
this.mlTaskManager = mlTaskManager;
104111
this.mlModelManager = mlModelManager;
@@ -177,25 +184,39 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel
177184
RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS);
178185
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build();
179186
String modelId = mlTask.getModelId();
187+
User user = RestActionUtils.getUserContext(client);
180188

181189
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
182190
ActionListener<MLModel> getModelListener = ActionListener.wrap(model -> {
183-
if (model.getConnector() != null) {
184-
Connector connector = model.getConnector();
185-
executeConnector(connector, mlInput, actionListener);
186-
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
187-
ActionListener<Connector> listener = ActionListener
188-
.wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> {
189-
log.error("Failed to get connector " + model.getConnectorId(), e);
190-
actionListener.onFailure(e);
191-
});
192-
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
193-
connectorAccessControlHelper
194-
.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore));
191+
modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> {
192+
if (!access) {
193+
actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job"));
194+
} else {
195+
if (model.getConnector() != null) {
196+
Connector connector = model.getConnector();
197+
executeConnector(connector, mlInput, actionListener);
198+
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
199+
ActionListener<Connector> listener = ActionListener
200+
.wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> {
201+
log.error("Failed to get connector " + model.getConnectorId(), e);
202+
actionListener.onFailure(e);
203+
});
204+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
205+
connectorAccessControlHelper
206+
.getConnector(
207+
client,
208+
model.getConnectorId(),
209+
ActionListener.runBefore(listener, threadContext::restore)
210+
);
211+
}
212+
} else {
213+
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
214+
}
195215
}
196-
} else {
197-
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
198-
}
216+
}, e -> {
217+
log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e);
218+
actionListener.onFailure(e);
219+
}));
199220
}, e -> {
200221
log.error("Failed to retrieve the ML model with the given ID", e);
201222
actionListener
@@ -211,26 +232,20 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel
211232
}
212233

213234
private void executeConnector(Connector connector, MLInput mlInput, ActionListener<MLCancelBatchJobResponse> actionListener) {
214-
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
215-
Optional<ConnectorAction> cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name());
216-
if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) {
217-
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT);
218-
connector.addAction(connectorAction);
219-
}
220-
connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential));
221-
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
222-
.initInstance(connector.getProtocol(), connector, Connector.class);
223-
connectorExecutor.setScriptService(scriptService);
224-
connectorExecutor.setClusterService(clusterService);
225-
connectorExecutor.setClient(client);
226-
connectorExecutor.setXContentRegistry(xContentRegistry);
227-
connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> {
228-
processTaskResponse(taskResponse, actionListener);
229-
}, e -> { actionListener.onFailure(e); }));
230-
} else {
231-
actionListener
232-
.onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN));
235+
Optional<ConnectorAction> cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name());
236+
if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) {
237+
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT);
238+
connector.addAction(connectorAction);
233239
}
240+
connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential));
241+
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
242+
connectorExecutor.setScriptService(scriptService);
243+
connectorExecutor.setClusterService(clusterService);
244+
connectorExecutor.setClient(client);
245+
connectorExecutor.setXContentRegistry(xContentRegistry);
246+
connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> {
247+
processTaskResponse(taskResponse, actionListener);
248+
}, e -> { actionListener.onFailure(e); }));
234249
}
235250

236251
private void processTaskResponse(MLTaskResponse taskResponse, ActionListener<MLCancelBatchJobResponse> actionListener) {

plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.opensearch.common.settings.Setting;
4646
import org.opensearch.common.settings.Settings;
4747
import org.opensearch.common.util.concurrent.ThreadContext;
48+
import org.opensearch.commons.authuser.User;
4849
import org.opensearch.core.action.ActionListener;
4950
import org.opensearch.core.rest.RestStatus;
5051
import org.opensearch.core.xcontent.NamedXContentRegistry;
@@ -59,6 +60,7 @@
5960
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
6061
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
6162
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
63+
import org.opensearch.ml.common.exception.MLValidationException;
6264
import org.opensearch.ml.common.input.MLInput;
6365
import org.opensearch.ml.common.output.model.ModelTensorOutput;
6466
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -71,9 +73,11 @@
7173
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
7274
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
7375
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
76+
import org.opensearch.ml.helper.ModelAccessControlHelper;
7477
import org.opensearch.ml.model.MLModelManager;
7578
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
7679
import org.opensearch.ml.task.MLTaskManager;
80+
import org.opensearch.ml.utils.RestActionUtils;
7781
import org.opensearch.script.ScriptService;
7882
import org.opensearch.tasks.Task;
7983
import org.opensearch.transport.TransportService;
@@ -90,6 +94,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
9094
ScriptService scriptService;
9195

9296
ConnectorAccessControlHelper connectorAccessControlHelper;
97+
ModelAccessControlHelper modelAccessControlHelper;
9398
EncryptorImpl encryptor;
9499
MLModelManager mlModelManager;
95100

@@ -111,6 +116,7 @@ public GetTaskTransportAction(
111116
ClusterService clusterService,
112117
ScriptService scriptService,
113118
ConnectorAccessControlHelper connectorAccessControlHelper,
119+
ModelAccessControlHelper modelAccessControlHelper,
114120
EncryptorImpl encryptor,
115121
MLTaskManager mlTaskManager,
116122
MLModelManager mlModelManager,
@@ -123,6 +129,7 @@ public GetTaskTransportAction(
123129
this.clusterService = clusterService;
124130
this.scriptService = scriptService;
125131
this.connectorAccessControlHelper = connectorAccessControlHelper;
132+
this.modelAccessControlHelper = modelAccessControlHelper;
126133
this.encryptor = encryptor;
127134
this.mlTaskManager = mlTaskManager;
128135
this.mlModelManager = mlModelManager;
@@ -238,26 +245,40 @@ private void processRemoteBatchPrediction(MLTask mlTask, String taskId, ActionLi
238245
RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS);
239246
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build();
240247
String modelId = mlTask.getModelId();
248+
User user = RestActionUtils.getUserContext(client);
241249

242250
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
243251
ActionListener<MLModel> getModelListener = ActionListener.wrap(model -> {
244-
if (model.getConnector() != null) {
245-
Connector connector = model.getConnector();
246-
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
247-
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
248-
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
249-
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
250-
}, e -> {
251-
log.error("Failed to get connector " + model.getConnectorId(), e);
252-
actionListener.onFailure(e);
253-
});
254-
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
255-
connectorAccessControlHelper
256-
.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore));
252+
modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> {
253+
if (!access) {
254+
actionListener.onFailure(new MLValidationException("You don't have permission to access this batch job"));
255+
} else {
256+
if (model.getConnector() != null) {
257+
Connector connector = model.getConnector();
258+
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
259+
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
260+
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
261+
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
262+
}, e -> {
263+
log.error("Failed to get connector " + model.getConnectorId(), e);
264+
actionListener.onFailure(e);
265+
});
266+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
267+
connectorAccessControlHelper
268+
.getConnector(
269+
client,
270+
model.getConnectorId(),
271+
ActionListener.runBefore(listener, threadContext::restore)
272+
);
273+
}
274+
} else {
275+
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
276+
}
257277
}
258-
} else {
259-
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
260-
}
278+
}, e -> {
279+
log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e);
280+
actionListener.onFailure(e);
281+
}));
261282
}, e -> {
262283
log.error("Failed to retrieve the ML model for the given task ID", e);
263284
actionListener
@@ -280,26 +301,20 @@ private void executeConnector(
280301
Map<String, Object> remoteJob,
281302
ActionListener<MLTaskGetResponse> actionListener
282303
) {
283-
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
284-
Optional<ConnectorAction> batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name());
285-
if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) {
286-
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS);
287-
connector.addAction(connectorAction);
288-
}
289-
connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential));
290-
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
291-
.initInstance(connector.getProtocol(), connector, Connector.class);
292-
connectorExecutor.setScriptService(scriptService);
293-
connectorExecutor.setClusterService(clusterService);
294-
connectorExecutor.setClient(client);
295-
connectorExecutor.setXContentRegistry(xContentRegistry);
296-
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
297-
processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener);
298-
}, e -> { actionListener.onFailure(e); }));
299-
} else {
300-
actionListener
301-
.onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN));
304+
Optional<ConnectorAction> batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name());
305+
if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) {
306+
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS);
307+
connector.addAction(connectorAction);
302308
}
309+
connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential));
310+
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
311+
connectorExecutor.setScriptService(scriptService);
312+
connectorExecutor.setClusterService(clusterService);
313+
connectorExecutor.setClient(client);
314+
connectorExecutor.setXContentRegistry(xContentRegistry);
315+
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
316+
processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener);
317+
}, e -> { actionListener.onFailure(e); }));
303318
}
304319

305320
protected void processTaskResponse(

0 commit comments

Comments
 (0)