Skip to content

Commit ba9735c

Browse files
authored
add security IT (#168)
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 6b5fdef commit ba9735c

File tree

4 files changed

+337
-7
lines changed

4 files changed

+337
-7
lines changed

plugin/build.gradle

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ integTest {
104104
}
105105
}
106106

107+
if (System.getProperty("https") == null || System.getProperty("https") == "false") {
108+
filter {
109+
excludeTestsMatching "org.opensearch.ml.rest.SecureMLRestIT"
110+
}
111+
}
112+
107113
// The 'doFirst' delays till execution time.
108114
doFirst {
109115
// Tell the test JVM if the cluster JVM is running under a debugger so that tests can
@@ -182,7 +188,7 @@ run {
182188
task release(type: Copy, group: 'build') {
183189
dependsOn allprojects*.tasks.build
184190
from(zipTree(project.tasks.bundlePlugin.outputs.files.getSingleFile()))
185-
into "build/plugins/opensearch-machine-learning"
191+
into "build/plugins/opensearch-ml"
186192
includeEmptyDirs = false
187193
}
188194

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

Lines changed: 172 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@
1818
import java.net.URI;
1919
import java.net.URISyntaxException;
2020
import java.nio.file.Path;
21+
import java.util.ArrayList;
2122
import java.util.Collections;
2223
import java.util.List;
24+
import java.util.Locale;
2325
import java.util.Map;
2426
import java.util.Objects;
2527
import java.util.Optional;
28+
import java.util.function.Consumer;
2629
import java.util.stream.Collectors;
2730

2831
import org.apache.http.Header;
@@ -51,17 +54,23 @@
5154
import org.opensearch.common.xcontent.XContentParser;
5255
import org.opensearch.common.xcontent.XContentType;
5356
import org.opensearch.commons.rest.SecureRestClientBuilder;
57+
import org.opensearch.ml.common.dataset.MLInputDataset;
58+
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
5459
import org.opensearch.ml.common.parameter.FunctionName;
60+
import org.opensearch.ml.common.parameter.MLAlgoParams;
61+
import org.opensearch.ml.common.parameter.MLInput;
5562
import org.opensearch.ml.stats.ActionName;
5663
import org.opensearch.ml.stats.StatNames;
5764
import org.opensearch.ml.utils.TestData;
5865
import org.opensearch.ml.utils.TestHelper;
5966
import org.opensearch.rest.RestStatus;
67+
import org.opensearch.search.builder.SearchSourceBuilder;
6068
import org.opensearch.test.rest.OpenSearchRestTestCase;
6169

6270
import com.google.common.collect.ImmutableList;
6371
import com.google.common.collect.ImmutableMap;
6472
import com.google.gson.Gson;
73+
import com.google.gson.JsonArray;
6574

6675
public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase {
6776
protected Gson gson = new Gson();
@@ -241,7 +250,7 @@ protected Response ingestIrisData(String indexName) throws IOException {
241250
"POST",
242251
"_bulk?refresh=true",
243252
null,
244-
TestHelper.toHttpEntity(TestData.IRIS_DATA),
253+
TestHelper.toHttpEntity(TestData.IRIS_DATA.replaceAll("iris_data", indexName)),
245254
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
246255
);
247256
assertEquals(RestStatus.OK, TestHelper.restStatus(statsResponse));
@@ -253,7 +262,7 @@ protected void validateStats(
253262
ActionName actionName,
254263
int expectedTotalFailureCount,
255264
int expectedTotalAlgoFailureCount,
256-
int expectedTotalRequestCount,
265+
int expectedMinumnTotalRequestCount,
257266
int expectedTotalAlgoRequestCount
258267
) throws IOException {
259268
Response statsResponse = TestHelper.makeRequest(client(), "GET", "_plugins/_ml/stats", null, "", null);
@@ -284,8 +293,7 @@ protected void validateStats(
284293
}
285294
assertEquals(expectedTotalFailureCount, totalFailureCount);
286295
assertEquals(expectedTotalAlgoFailureCount, totalAlgoFailureCount);
287-
// ToDo: this line makes this test flaky as other tests makes the request count not predictable
288-
// assertEquals(expectedTotalRequestCount, totalRequestCount);
296+
assertTrue(totalRequestCount >= expectedMinumnTotalRequestCount);
289297
assertEquals(expectedTotalAlgoRequestCount, totalAlgoRequestCount);
290298
}
291299

@@ -296,4 +304,164 @@ protected Response ingestModelData() throws IOException {
296304
assertNotNull(trainModelResponse);
297305
return trainModelResponse;
298306
}
307+
308+
public Response createIndexRole(String role, String index) throws IOException {
309+
return TestHelper
310+
.makeRequest(
311+
client(),
312+
"PUT",
313+
"/_opendistro/_security/api/roles/" + role,
314+
null,
315+
TestHelper
316+
.toHttpEntity(
317+
"{\n"
318+
+ "\"cluster_permissions\": [\n"
319+
+ "],\n"
320+
+ "\"index_permissions\": [\n"
321+
+ "{\n"
322+
+ "\"index_patterns\": [\n"
323+
+ "\""
324+
+ index
325+
+ "\"\n"
326+
+ "],\n"
327+
+ "\"dls\": \"\",\n"
328+
+ "\"fls\": [],\n"
329+
+ "\"masked_fields\": [],\n"
330+
+ "\"allowed_actions\": [\n"
331+
+ "\"crud\",\n"
332+
+ "\"indices:admin/create\"\n"
333+
+ "]\n"
334+
+ "}\n"
335+
+ "],\n"
336+
+ "\"tenant_permissions\": []\n"
337+
+ "}"
338+
),
339+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
340+
);
341+
}
342+
343+
public Response createSearchRole(String role, String index) throws IOException {
344+
return TestHelper
345+
.makeRequest(
346+
client(),
347+
"PUT",
348+
"/_opendistro/_security/api/roles/" + role,
349+
null,
350+
TestHelper
351+
.toHttpEntity(
352+
"{\n"
353+
+ "\"cluster_permissions\": [\n"
354+
+ "],\n"
355+
+ "\"index_permissions\": [\n"
356+
+ "{\n"
357+
+ "\"index_patterns\": [\n"
358+
+ "\""
359+
+ index
360+
+ "\"\n"
361+
+ "],\n"
362+
+ "\"dls\": \"\",\n"
363+
+ "\"fls\": [],\n"
364+
+ "\"masked_fields\": [],\n"
365+
+ "\"allowed_actions\": [\n"
366+
+ "\"indices:data/read/search\"\n"
367+
+ "]\n"
368+
+ "}\n"
369+
+ "],\n"
370+
+ "\"tenant_permissions\": []\n"
371+
+ "}"
372+
),
373+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
374+
);
375+
}
376+
377+
public Response createUser(String name, String password, ArrayList<String> backendRoles) throws IOException {
378+
JsonArray backendRolesString = new JsonArray();
379+
for (int i = 0; i < backendRoles.size(); i++) {
380+
backendRolesString.add(backendRoles.get(i));
381+
}
382+
return TestHelper
383+
.makeRequest(
384+
client(),
385+
"PUT",
386+
"/_opendistro/_security/api/internalusers/" + name,
387+
null,
388+
TestHelper
389+
.toHttpEntity(
390+
" {\n"
391+
+ "\"password\": \""
392+
+ password
393+
+ "\",\n"
394+
+ "\"backend_roles\": "
395+
+ backendRolesString
396+
+ ",\n"
397+
+ "\"attributes\": {\n"
398+
+ "}} "
399+
),
400+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
401+
);
402+
}
403+
404+
public Response deleteUser(String user) throws IOException {
405+
return TestHelper
406+
.makeRequest(
407+
client(),
408+
"DELETE",
409+
"/_opendistro/_security/api/internalusers/" + user,
410+
null,
411+
"",
412+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
413+
);
414+
}
415+
416+
public Response createRoleMapping(String role, ArrayList<String> users) throws IOException {
417+
JsonArray usersString = new JsonArray();
418+
for (int i = 0; i < users.size(); i++) {
419+
usersString.add(users.get(i));
420+
}
421+
return TestHelper
422+
.makeRequest(
423+
client(),
424+
"PUT",
425+
"/_opendistro/_security/api/rolesmapping/" + role,
426+
null,
427+
TestHelper
428+
.toHttpEntity(
429+
"{\n" + " \"backend_roles\" : [ ],\n" + " \"hosts\" : [ ],\n" + " \"users\" : " + usersString + "\n" + "}"
430+
),
431+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
432+
);
433+
}
434+
435+
public void trainAndPredict(
436+
RestClient client,
437+
FunctionName functionName,
438+
String indexName,
439+
MLAlgoParams params,
440+
SearchSourceBuilder searchSourceBuilder,
441+
Consumer<Map<String, Object>> function
442+
) throws IOException {
443+
MLInputDataset inputData = SearchQueryInputDataset
444+
.builder()
445+
.indices(ImmutableList.of(indexName))
446+
.searchSourceBuilder(searchSourceBuilder)
447+
.build();
448+
MLInput kmeansInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build();
449+
Response response = TestHelper
450+
.makeRequest(
451+
client,
452+
"POST",
453+
"/_plugins/_ml/_train_predict/" + functionName.name().toLowerCase(Locale.ROOT),
454+
ImmutableMap.of(),
455+
TestHelper.toHttpEntity(kmeansInput),
456+
null
457+
);
458+
HttpEntity entity = response.getEntity();
459+
assertNotNull(response);
460+
String entityString = TestHelper.httpEntityToString(entity);
461+
Map map = gson.fromJson(entityString, Map.class);
462+
Map<String, Object> predictionResult = (Map<String, Object>) map.get("prediction_result");
463+
if (function != null) {
464+
function.accept(predictionResult);
465+
}
466+
}
299467
}

plugin/src/test/java/org/opensearch/ml/rest/RestMLTrainAndPredictIT.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import java.util.function.Consumer;
1313

1414
import org.apache.http.HttpEntity;
15+
import org.junit.After;
16+
import org.junit.Before;
1517
import org.opensearch.client.Response;
1618
import org.opensearch.index.query.MatchAllQueryBuilder;
1719
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -27,11 +29,20 @@
2729
import com.google.common.collect.ImmutableMap;
2830

2931
public class RestMLTrainAndPredictIT extends MLCommonsRestTestCase {
30-
private String irisIndex = "iris_data";
32+
private String irisIndex = "iris_data_train_predict_it";
33+
34+
@Before
35+
public void setup() throws IOException {
36+
ingestIrisData(irisIndex);
37+
}
38+
39+
@After
40+
public void deleteIndices() throws IOException {
41+
deleteIndexWithAdminClient(irisIndex);
42+
}
3143

3244
public void testTrainAndPredictKmeans() throws IOException {
3345
validateStats(FunctionName.KMEANS, ActionName.TRAIN_PREDICT, 0, 0, 0, 0);
34-
ingestIrisData(irisIndex);
3546
trainAndPredictKmeansWithCustomParam();
3647
validateStats(FunctionName.KMEANS, ActionName.TRAIN_PREDICT, 0, 0, 1, 1);
3748

0 commit comments

Comments
 (0)