1818import java .net .URI ;
1919import java .net .URISyntaxException ;
2020import java .nio .file .Path ;
21+ import java .util .ArrayList ;
2122import java .util .Collections ;
2223import java .util .List ;
24+ import java .util .Locale ;
2325import java .util .Map ;
2426import java .util .Objects ;
2527import java .util .Optional ;
28+ import java .util .function .Consumer ;
2629import java .util .stream .Collectors ;
2730
2831import org .apache .http .Header ;
5154import org .opensearch .common .xcontent .XContentParser ;
5255import org .opensearch .common .xcontent .XContentType ;
5356import org .opensearch .commons .rest .SecureRestClientBuilder ;
57+ import org .opensearch .ml .common .dataset .MLInputDataset ;
58+ import org .opensearch .ml .common .dataset .SearchQueryInputDataset ;
5459import org .opensearch .ml .common .parameter .FunctionName ;
60+ import org .opensearch .ml .common .parameter .MLAlgoParams ;
61+ import org .opensearch .ml .common .parameter .MLInput ;
5562import org .opensearch .ml .stats .ActionName ;
5663import org .opensearch .ml .stats .StatNames ;
5764import org .opensearch .ml .utils .TestData ;
5865import org .opensearch .ml .utils .TestHelper ;
5966import org .opensearch .rest .RestStatus ;
67+ import org .opensearch .search .builder .SearchSourceBuilder ;
6068import org .opensearch .test .rest .OpenSearchRestTestCase ;
6169
6270import com .google .common .collect .ImmutableList ;
6371import com .google .common .collect .ImmutableMap ;
6472import com .google .gson .Gson ;
73+ import com .google .gson .JsonArray ;
6574
6675public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase {
6776 protected Gson gson = new Gson ();
@@ -241,7 +250,7 @@ protected Response ingestIrisData(String indexName) throws IOException {
241250 "POST" ,
242251 "_bulk?refresh=true" ,
243252 null ,
244- TestHelper .toHttpEntity (TestData .IRIS_DATA ),
253+ TestHelper .toHttpEntity (TestData .IRIS_DATA . replaceAll ( "iris_data" , indexName ) ),
245254 ImmutableList .of (new BasicHeader (HttpHeaders .USER_AGENT , "" ))
246255 );
247256 assertEquals (RestStatus .OK , TestHelper .restStatus (statsResponse ));
@@ -253,7 +262,7 @@ protected void validateStats(
253262 ActionName actionName ,
254263 int expectedTotalFailureCount ,
255264 int expectedTotalAlgoFailureCount ,
256- int expectedTotalRequestCount ,
265+ int expectedMinumnTotalRequestCount ,
257266 int expectedTotalAlgoRequestCount
258267 ) throws IOException {
259268 Response statsResponse = TestHelper .makeRequest (client (), "GET" , "_plugins/_ml/stats" , null , "" , null );
@@ -284,8 +293,7 @@ protected void validateStats(
284293 }
285294 assertEquals (expectedTotalFailureCount , totalFailureCount );
286295 assertEquals (expectedTotalAlgoFailureCount , totalAlgoFailureCount );
287- // ToDo: this line makes this test flaky as other tests makes the request count not predictable
288- // assertEquals(expectedTotalRequestCount, totalRequestCount);
296+ assertTrue (totalRequestCount >= expectedMinumnTotalRequestCount );
289297 assertEquals (expectedTotalAlgoRequestCount , totalAlgoRequestCount );
290298 }
291299
@@ -296,4 +304,164 @@ protected Response ingestModelData() throws IOException {
296304 assertNotNull (trainModelResponse );
297305 return trainModelResponse ;
298306 }
307+
308+ public Response createIndexRole (String role , String index ) throws IOException {
309+ return TestHelper
310+ .makeRequest (
311+ client (),
312+ "PUT" ,
313+ "/_opendistro/_security/api/roles/" + role ,
314+ null ,
315+ TestHelper
316+ .toHttpEntity (
317+ "{\n "
318+ + "\" cluster_permissions\" : [\n "
319+ + "],\n "
320+ + "\" index_permissions\" : [\n "
321+ + "{\n "
322+ + "\" index_patterns\" : [\n "
323+ + "\" "
324+ + index
325+ + "\" \n "
326+ + "],\n "
327+ + "\" dls\" : \" \" ,\n "
328+ + "\" fls\" : [],\n "
329+ + "\" masked_fields\" : [],\n "
330+ + "\" allowed_actions\" : [\n "
331+ + "\" crud\" ,\n "
332+ + "\" indices:admin/create\" \n "
333+ + "]\n "
334+ + "}\n "
335+ + "],\n "
336+ + "\" tenant_permissions\" : []\n "
337+ + "}"
338+ ),
339+ ImmutableList .of (new BasicHeader (HttpHeaders .USER_AGENT , "Kibana" ))
340+ );
341+ }
342+
343+ public Response createSearchRole (String role , String index ) throws IOException {
344+ return TestHelper
345+ .makeRequest (
346+ client (),
347+ "PUT" ,
348+ "/_opendistro/_security/api/roles/" + role ,
349+ null ,
350+ TestHelper
351+ .toHttpEntity (
352+ "{\n "
353+ + "\" cluster_permissions\" : [\n "
354+ + "],\n "
355+ + "\" index_permissions\" : [\n "
356+ + "{\n "
357+ + "\" index_patterns\" : [\n "
358+ + "\" "
359+ + index
360+ + "\" \n "
361+ + "],\n "
362+ + "\" dls\" : \" \" ,\n "
363+ + "\" fls\" : [],\n "
364+ + "\" masked_fields\" : [],\n "
365+ + "\" allowed_actions\" : [\n "
366+ + "\" indices:data/read/search\" \n "
367+ + "]\n "
368+ + "}\n "
369+ + "],\n "
370+ + "\" tenant_permissions\" : []\n "
371+ + "}"
372+ ),
373+ ImmutableList .of (new BasicHeader (HttpHeaders .USER_AGENT , "Kibana" ))
374+ );
375+ }
376+
377+ public Response createUser (String name , String password , ArrayList <String > backendRoles ) throws IOException {
378+ JsonArray backendRolesString = new JsonArray ();
379+ for (int i = 0 ; i < backendRoles .size (); i ++) {
380+ backendRolesString .add (backendRoles .get (i ));
381+ }
382+ return TestHelper
383+ .makeRequest (
384+ client (),
385+ "PUT" ,
386+ "/_opendistro/_security/api/internalusers/" + name ,
387+ null ,
388+ TestHelper
389+ .toHttpEntity (
390+ " {\n "
391+ + "\" password\" : \" "
392+ + password
393+ + "\" ,\n "
394+ + "\" backend_roles\" : "
395+ + backendRolesString
396+ + ",\n "
397+ + "\" attributes\" : {\n "
398+ + "}} "
399+ ),
400+ ImmutableList .of (new BasicHeader (HttpHeaders .USER_AGENT , "Kibana" ))
401+ );
402+ }
403+
404+ public Response deleteUser (String user ) throws IOException {
405+ return TestHelper
406+ .makeRequest (
407+ client (),
408+ "DELETE" ,
409+ "/_opendistro/_security/api/internalusers/" + user ,
410+ null ,
411+ "" ,
412+ ImmutableList .of (new BasicHeader (HttpHeaders .USER_AGENT , "Kibana" ))
413+ );
414+ }
415+
416+ public Response createRoleMapping (String role , ArrayList <String > users ) throws IOException {
417+ JsonArray usersString = new JsonArray ();
418+ for (int i = 0 ; i < users .size (); i ++) {
419+ usersString .add (users .get (i ));
420+ }
421+ return TestHelper
422+ .makeRequest (
423+ client (),
424+ "PUT" ,
425+ "/_opendistro/_security/api/rolesmapping/" + role ,
426+ null ,
427+ TestHelper
428+ .toHttpEntity (
429+ "{\n " + " \" backend_roles\" : [ ],\n " + " \" hosts\" : [ ],\n " + " \" users\" : " + usersString + "\n " + "}"
430+ ),
431+ ImmutableList .of (new BasicHeader (HttpHeaders .USER_AGENT , "Kibana" ))
432+ );
433+ }
434+
435+ public void trainAndPredict (
436+ RestClient client ,
437+ FunctionName functionName ,
438+ String indexName ,
439+ MLAlgoParams params ,
440+ SearchSourceBuilder searchSourceBuilder ,
441+ Consumer <Map <String , Object >> function
442+ ) throws IOException {
443+ MLInputDataset inputData = SearchQueryInputDataset
444+ .builder ()
445+ .indices (ImmutableList .of (indexName ))
446+ .searchSourceBuilder (searchSourceBuilder )
447+ .build ();
448+ MLInput kmeansInput = MLInput .builder ().algorithm (functionName ).parameters (params ).inputDataset (inputData ).build ();
449+ Response response = TestHelper
450+ .makeRequest (
451+ client ,
452+ "POST" ,
453+ "/_plugins/_ml/_train_predict/" + functionName .name ().toLowerCase (Locale .ROOT ),
454+ ImmutableMap .of (),
455+ TestHelper .toHttpEntity (kmeansInput ),
456+ null
457+ );
458+ HttpEntity entity = response .getEntity ();
459+ assertNotNull (response );
460+ String entityString = TestHelper .httpEntityToString (entity );
461+ Map map = gson .fromJson (entityString , Map .class );
462+ Map <String , Object > predictionResult = (Map <String , Object >) map .get ("prediction_result" );
463+ if (function != null ) {
464+ function .accept (predictionResult );
465+ }
466+ }
299467}
0 commit comments