Skip to content

Commit bf1d068

Browse files
fix no permission to create model/task index bug;add security IT for train/predict API (#177) (#181)
* fix no permission to create model/task index bug;add security IT for train/predict API Signed-off-by: Yaliang Wu <[email protected]> * add more security IT for readonly user Signed-off-by: Yaliang Wu <[email protected]> * throw exception if delete model/task successfully for readonly user Signed-off-by: Yaliang Wu <[email protected]> (cherry picked from commit f12ca76) Co-authored-by: Yaliang Wu <[email protected]>
1 parent 75c94a2 commit bf1d068

File tree

4 files changed

+285
-24
lines changed

4 files changed

+285
-24
lines changed

plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212

1313
import org.opensearch.action.ActionListener;
1414
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
15+
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
1516
import org.opensearch.client.Client;
1617
import org.opensearch.cluster.service.ClusterService;
18+
import org.opensearch.common.util.concurrent.ThreadContext;
1719
import org.opensearch.common.xcontent.XContentType;
1820

1921
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@@ -87,19 +89,24 @@ public void initMLTaskIndex(ActionListener<Boolean> listener) {
8789

8890
public void initMLIndexIfAbsent(String indexName, String mapping, ActionListener<Boolean> listener) {
8991
if (!clusterService.state().metadata().hasIndex(indexName)) {
90-
CreateIndexRequest request = new CreateIndexRequest(indexName).mapping("_doc", mapping, XContentType.JSON);
91-
92-
client.admin().indices().create(request, ActionListener.wrap(r -> {
93-
if (r.isAcknowledged()) {
94-
log.info("create index:{}", indexName);
95-
listener.onResponse(true);
96-
} else {
97-
listener.onResponse(false);
98-
}
99-
}, e -> {
100-
log.error("Failed to create index " + indexName, e);
92+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
93+
ActionListener<CreateIndexResponse> actionListener = ActionListener.wrap(r -> {
94+
if (r.isAcknowledged()) {
95+
log.info("create index:{}", indexName);
96+
listener.onResponse(true);
97+
} else {
98+
listener.onResponse(false);
99+
}
100+
}, e -> {
101+
log.error("Failed to create index " + indexName, e);
102+
listener.onFailure(e);
103+
});
104+
CreateIndexRequest request = new CreateIndexRequest(indexName).mapping("_doc", mapping, XContentType.JSON);
105+
client.admin().indices().create(request, ActionListener.runBefore(actionListener, () -> threadContext.restore()));
106+
} catch (Exception e) {
107+
log.error("Failed to init index " + indexName, e);
101108
listener.onFailure(e);
102-
}));
109+
}
103110
} else {
104111
log.info("index:{} is already created", indexName);
105112
listener.onResponse(true);

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

Lines changed: 101 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,10 @@ protected Response ingestIrisData(String indexName) throws IOException {
260260
protected void validateStats(
261261
FunctionName functionName,
262262
ActionName actionName,
263-
int expectedTotalFailureCount,
264-
int expectedTotalAlgoFailureCount,
265-
int expectedMinumnTotalRequestCount,
266-
int expectedTotalAlgoRequestCount
263+
int expectedMinimumTotalFailureCount,
264+
int expectedMinimumTotalAlgoFailureCount,
265+
int expectedMinimumTotalRequestCount,
266+
int expectedMinimumTotalAlgoRequestCount
267267
) throws IOException {
268268
Response statsResponse = TestHelper.makeRequest(client(), "GET", "_plugins/_ml/stats", null, "", null);
269269
HttpEntity entity = statsResponse.getEntity();
@@ -291,10 +291,10 @@ protected void validateStats(
291291
totalAlgoRequestCount += (Double) nodeStatsMap.get(requestCountStat);
292292
}
293293
}
294-
assertEquals(expectedTotalFailureCount, totalFailureCount);
295-
assertEquals(expectedTotalAlgoFailureCount, totalAlgoFailureCount);
296-
assertTrue(totalRequestCount >= expectedMinumnTotalRequestCount);
297-
assertEquals(expectedTotalAlgoRequestCount, totalAlgoRequestCount);
294+
assertTrue(totalFailureCount >= expectedMinimumTotalFailureCount);
295+
assertTrue(totalAlgoFailureCount >= expectedMinimumTotalAlgoFailureCount);
296+
assertTrue(totalRequestCount >= expectedMinimumTotalRequestCount);
297+
assertTrue(totalAlgoRequestCount >= expectedMinimumTotalAlgoRequestCount);
298298
}
299299

300300
protected Response ingestModelData() throws IOException {
@@ -464,4 +464,97 @@ public void trainAndPredict(
464464
function.accept(predictionResult);
465465
}
466466
}
467+
468+
public void train(
469+
RestClient client,
470+
FunctionName functionName,
471+
String indexName,
472+
MLAlgoParams params,
473+
SearchSourceBuilder searchSourceBuilder,
474+
Consumer<Map<String, Object>> function,
475+
boolean async
476+
) throws IOException {
477+
MLInputDataset inputData = SearchQueryInputDataset
478+
.builder()
479+
.indices(ImmutableList.of(indexName))
480+
.searchSourceBuilder(searchSourceBuilder)
481+
.build();
482+
MLInput kmeansInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build();
483+
String endpoint = "/_plugins/_ml/_train/" + functionName.name().toLowerCase(Locale.ROOT);
484+
if (async) {
485+
endpoint += "?async=true";
486+
}
487+
Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(kmeansInput), null);
488+
verifyResponse(function, response);
489+
}
490+
491+
public void predict(
492+
RestClient client,
493+
FunctionName functionName,
494+
String modelId,
495+
String indexName,
496+
MLAlgoParams params,
497+
SearchSourceBuilder searchSourceBuilder,
498+
Consumer<Map<String, Object>> function
499+
) throws IOException {
500+
MLInputDataset inputData = SearchQueryInputDataset
501+
.builder()
502+
.indices(ImmutableList.of(indexName))
503+
.searchSourceBuilder(searchSourceBuilder)
504+
.build();
505+
MLInput kmeansInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build();
506+
String endpoint = "/_plugins/_ml/_predict/" + functionName.name().toLowerCase(Locale.ROOT) + "/" + modelId;
507+
Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(kmeansInput), null);
508+
verifyResponse(function, response);
509+
}
510+
511+
public void getModel(RestClient client, String modelId, Consumer<Map<String, Object>> function) throws IOException {
512+
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/models/" + modelId, null, "", null);
513+
verifyResponse(function, response);
514+
}
515+
516+
public void getTask(RestClient client, String taskId, Consumer<Map<String, Object>> function) throws IOException {
517+
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/tasks/" + taskId, null, "", null);
518+
verifyResponse(function, response);
519+
}
520+
521+
public void deleteModel(RestClient client, String modelId, Consumer<Map<String, Object>> function) throws IOException {
522+
Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/models/" + modelId, null, "", null);
523+
verifyResponse(function, response);
524+
}
525+
526+
public void deleteTask(RestClient client, String taskId, Consumer<Map<String, Object>> function) throws IOException {
527+
Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/tasks/" + taskId, null, "", null);
528+
verifyResponse(function, response);
529+
}
530+
531+
public void searchModelsWithAlgoName(RestClient client, String algoName, Consumer<Map<String, Object>> function) throws IOException {
532+
String query = String.format(Locale.ROOT, "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"algorithm\":\"%s\"}}]}}}", algoName);
533+
searchModels(client, query, function);
534+
}
535+
536+
public void searchModels(RestClient client, String query, Consumer<Map<String, Object>> function) throws IOException {
537+
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/models/_search", null, query, null);
538+
verifyResponse(function, response);
539+
}
540+
541+
public void searchTasksWithAlgoName(RestClient client, String algoName, Consumer<Map<String, Object>> function) throws IOException {
542+
String query = String.format(Locale.ROOT, "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"function_name\":\"%s\"}}]}}}", algoName);
543+
searchTasks(client, query, function);
544+
}
545+
546+
public void searchTasks(RestClient client, String query, Consumer<Map<String, Object>> function) throws IOException {
547+
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/tasks/_search", null, query, null);
548+
verifyResponse(function, response);
549+
}
550+
551+
private void verifyResponse(Consumer<Map<String, Object>> function, Response response) throws IOException {
552+
HttpEntity entity = response.getEntity();
553+
assertNotNull(response);
554+
String entityString = TestHelper.httpEntityToString(entity);
555+
Map<String, Object> map = gson.fromJson(entityString, Map.class);
556+
if (function != null) {
557+
function.accept(map);
558+
}
559+
}
467560
}

plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.junit.Before;
1111
import org.junit.Rule;
12-
import org.junit.Test;
1312
import org.junit.rules.ExpectedException;
1413
import org.opensearch.common.Strings;
1514
import org.opensearch.rest.RestHandler;
@@ -27,20 +26,17 @@ public void setup() {
2726
restMLGetModelAction = new RestMLGetModelAction();
2827
}
2928

30-
@Test
3129
public void testConstructor() {
3230
RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction();
3331
assertNotNull(mlGetModelAction);
3432
}
3533

36-
@Test
3734
public void testGetName() {
3835
String actionName = restMLGetModelAction.getName();
3936
assertFalse(Strings.isNullOrEmpty(actionName));
4037
assertEquals("ml_get_model_action", actionName);
4138
}
4239

43-
@Test
4440
public void testRoutes() {
4541
List<RestHandler.Route> routes = restMLGetModelAction.routes();
4642
assertNotNull(routes);

plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.io.IOException;
99
import java.util.ArrayList;
1010
import java.util.Arrays;
11+
import java.util.Map;
1112

1213
import org.apache.http.HttpHost;
1314
import org.junit.After;
@@ -20,8 +21,11 @@
2021
import org.opensearch.index.query.MatchAllQueryBuilder;
2122
import org.opensearch.ml.common.parameter.FunctionName;
2223
import org.opensearch.ml.common.parameter.KMeansParams;
24+
import org.opensearch.ml.common.parameter.MLTaskState;
2325
import org.opensearch.search.builder.SearchSourceBuilder;
2426

27+
import com.google.common.base.Throwables;
28+
2529
public class SecureMLRestIT extends MLCommonsRestTestCase {
2630
private String irisIndex = "iris_data_secure_ml_it";
2731

@@ -129,6 +133,20 @@ public void testTrainAndPredictWithFullMLAccessNoIndexAccess() throws IOExceptio
129133
);
130134
}
131135

136+
public void testTrainWithReadOnlyMLAccess() throws IOException {
137+
exceptionRule.expect(ResponseException.class);
138+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/train]");
139+
KMeansParams kMeansParams = KMeansParams.builder().build();
140+
train(mlReadOnlyClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, null, false);
141+
}
142+
143+
public void testPredictWithReadOnlyMLAccess() throws IOException {
144+
exceptionRule.expect(ResponseException.class);
145+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/predict]");
146+
KMeansParams kMeansParams = KMeansParams.builder().build();
147+
predict(mlReadOnlyClient, FunctionName.KMEANS, "modelId", irisIndex, kMeansParams, searchSourceBuilder, null);
148+
}
149+
132150
public void testTrainAndPredictWithFullAccess() throws IOException {
133151
trainAndPredict(
134152
mlFullAccessClient,
@@ -142,4 +160,151 @@ public void testTrainAndPredictWithFullAccess() throws IOException {
142160
}
143161
);
144162
}
163+
164+
public void testTrainModelWithFullAccessThenPredict() throws IOException {
165+
KMeansParams kMeansParams = KMeansParams.builder().build();
166+
// train model
167+
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
168+
String modelId = (String) trainResult.get("model_id");
169+
assertNotNull(modelId);
170+
String status = (String) trainResult.get("status");
171+
assertEquals(MLTaskState.COMPLETED.name(), status);
172+
try {
173+
getModel(mlFullAccessClient, modelId, model -> {
174+
String algorithm = (String) model.get("algorithm");
175+
assertEquals(FunctionName.KMEANS.name(), algorithm);
176+
});
177+
} catch (IOException e) {
178+
assertNull(e);
179+
}
180+
try {
181+
// predict with trained model
182+
predict(mlFullAccessClient, FunctionName.KMEANS, modelId, irisIndex, kMeansParams, searchSourceBuilder, predictResult -> {
183+
String predictStatus = (String) predictResult.get("status");
184+
assertEquals(MLTaskState.COMPLETED.name(), predictStatus);
185+
Map<String, Object> predictionResult = (Map<String, Object>) predictResult.get("prediction_result");
186+
ArrayList rows = (ArrayList) predictionResult.get("rows");
187+
assertTrue(rows.size() > 1);
188+
});
189+
} catch (IOException e) {
190+
assertNull(e);
191+
}
192+
}, false);
193+
}
194+
195+
public void testTrainModelInAsyncWayWithFullAccess() throws IOException {
196+
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, KMeansParams.builder().build(), searchSourceBuilder, trainResult -> {
197+
assertFalse(trainResult.containsKey("model_id"));
198+
String taskId = (String) trainResult.get("task_id");
199+
assertNotNull(taskId);
200+
String status = (String) trainResult.get("status");
201+
assertEquals(MLTaskState.CREATED.name(), status);
202+
try {
203+
getTask(mlFullAccessClient, taskId, task -> {
204+
String algorithm = (String) task.get("function_name");
205+
assertEquals(FunctionName.KMEANS.name(), algorithm);
206+
});
207+
} catch (IOException e) {
208+
assertNull(e);
209+
}
210+
}, true);
211+
}
212+
213+
public void testReadOnlyUser_CanGetModel_CanNotDeleteModel() throws IOException {
214+
KMeansParams kMeansParams = KMeansParams.builder().build();
215+
// train model with full access client
216+
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
217+
String modelId = (String) trainResult.get("model_id");
218+
assertNotNull(modelId);
219+
String status = (String) trainResult.get("status");
220+
assertEquals(MLTaskState.COMPLETED.name(), status);
221+
try {
222+
// get model with readonly client
223+
getModel(mlReadOnlyClient, modelId, model -> {
224+
String algorithm = (String) model.get("algorithm");
225+
assertEquals(FunctionName.KMEANS.name(), algorithm);
226+
});
227+
} catch (IOException e) {
228+
assertNull(e);
229+
}
230+
try {
231+
// Failed to delete model with read only client
232+
deleteModel(mlReadOnlyClient, modelId, null);
233+
throw new RuntimeException("Delete model for readonly user does not fail");
234+
} catch (Exception e) {
235+
assertEquals(ResponseException.class, e.getClass());
236+
assertTrue(Throwables.getStackTraceAsString(e).contains("no permissions for [cluster:admin/opensearch/ml/models/delete]"));
237+
}
238+
}, false);
239+
}
240+
241+
public void testReadOnlyUser_CanGetTask_CanNotDeleteTask() throws IOException {
242+
KMeansParams kMeansParams = KMeansParams.builder().build();
243+
// train model with full access client
244+
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
245+
assertFalse(trainResult.containsKey("model_id"));
246+
String taskId = (String) trainResult.get("task_id");
247+
assertNotNull(taskId);
248+
String status = (String) trainResult.get("status");
249+
assertEquals(MLTaskState.CREATED.name(), status);
250+
try {
251+
// get task with readonly client
252+
getTask(mlReadOnlyClient, taskId, task -> {
253+
String algorithm = (String) task.get("function_name");
254+
assertEquals(FunctionName.KMEANS.name(), algorithm);
255+
});
256+
} catch (IOException e) {
257+
assertNull(e);
258+
}
259+
try {
260+
// Failed to delete task with read only client
261+
deleteTask(mlReadOnlyClient, taskId, null);
262+
throw new RuntimeException("Delete task for readonly user does not fail");
263+
} catch (Exception e) {
264+
assertEquals(ResponseException.class, e.getClass());
265+
assertTrue(Throwables.getStackTraceAsString(e).contains("no permissions for [cluster:admin/opensearch/ml/tasks/delete]"));
266+
}
267+
}, true);
268+
}
269+
270+
public void testReadOnlyUser_CanSearchModels() throws IOException {
271+
KMeansParams kMeansParams = KMeansParams.builder().build();
272+
// train model with full access client
273+
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
274+
String modelId = (String) trainResult.get("model_id");
275+
assertNotNull(modelId);
276+
String status = (String) trainResult.get("status");
277+
assertEquals(MLTaskState.COMPLETED.name(), status);
278+
try {
279+
// search model with readonly client
280+
searchModelsWithAlgoName(mlReadOnlyClient, FunctionName.KMEANS.name(), models -> {
281+
ArrayList<Object> hits = (ArrayList) ((Map<String, Object>) models.get("hits")).get("hits");
282+
assertTrue(hits.size() > 0);
283+
});
284+
} catch (IOException e) {
285+
assertNull(e);
286+
}
287+
}, false);
288+
}
289+
290+
public void testReadOnlyUser_CanSearchTasks() throws IOException {
291+
KMeansParams kMeansParams = KMeansParams.builder().build();
292+
// train model with full access client
293+
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
294+
assertFalse(trainResult.containsKey("model_id"));
295+
String taskId = (String) trainResult.get("task_id");
296+
assertNotNull(taskId);
297+
String status = (String) trainResult.get("status");
298+
assertEquals(MLTaskState.CREATED.name(), status);
299+
try {
300+
// search tasks with readonly client
301+
searchTasksWithAlgoName(mlReadOnlyClient, FunctionName.KMEANS.name(), tasks -> {
302+
ArrayList<Object> hits = (ArrayList) ((Map<String, Object>) tasks.get("hits")).get("hits");
303+
assertTrue(hits.size() > 0);
304+
});
305+
} catch (IOException e) {
306+
assertNull(e);
307+
}
308+
}, true);
309+
}
145310
}

0 commit comments

Comments
 (0)