88import java .io .IOException ;
99import java .util .ArrayList ;
1010import java .util .Arrays ;
11+ import java .util .Map ;
1112
1213import org .apache .http .HttpHost ;
1314import org .junit .After ;
2021import org .opensearch .index .query .MatchAllQueryBuilder ;
2122import org .opensearch .ml .common .parameter .FunctionName ;
2223import org .opensearch .ml .common .parameter .KMeansParams ;
24+ import org .opensearch .ml .common .parameter .MLTaskState ;
2325import org .opensearch .search .builder .SearchSourceBuilder ;
2426
27+ import com .google .common .base .Throwables ;
28+
2529public class SecureMLRestIT extends MLCommonsRestTestCase {
2630 private String irisIndex = "iris_data_secure_ml_it" ;
2731
@@ -129,6 +133,20 @@ public void testTrainAndPredictWithFullMLAccessNoIndexAccess() throws IOExceptio
129133 );
130134 }
131135
136+ public void testTrainWithReadOnlyMLAccess () throws IOException {
137+ exceptionRule .expect (ResponseException .class );
138+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/train]" );
139+ KMeansParams kMeansParams = KMeansParams .builder ().build ();
140+ train (mlReadOnlyClient , FunctionName .KMEANS , irisIndex , kMeansParams , searchSourceBuilder , null , false );
141+ }
142+
143+ public void testPredictWithReadOnlyMLAccess () throws IOException {
144+ exceptionRule .expect (ResponseException .class );
145+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/predict]" );
146+ KMeansParams kMeansParams = KMeansParams .builder ().build ();
147+ predict (mlReadOnlyClient , FunctionName .KMEANS , "modelId" , irisIndex , kMeansParams , searchSourceBuilder , null );
148+ }
149+
132150 public void testTrainAndPredictWithFullAccess () throws IOException {
133151 trainAndPredict (
134152 mlFullAccessClient ,
@@ -142,4 +160,151 @@ public void testTrainAndPredictWithFullAccess() throws IOException {
142160 }
143161 );
144162 }
163+
164+ public void testTrainModelWithFullAccessThenPredict () throws IOException {
165+ KMeansParams kMeansParams = KMeansParams .builder ().build ();
166+ // train model
167+ train (mlFullAccessClient , FunctionName .KMEANS , irisIndex , kMeansParams , searchSourceBuilder , trainResult -> {
168+ String modelId = (String ) trainResult .get ("model_id" );
169+ assertNotNull (modelId );
170+ String status = (String ) trainResult .get ("status" );
171+ assertEquals (MLTaskState .COMPLETED .name (), status );
172+ try {
173+ getModel (mlFullAccessClient , modelId , model -> {
174+ String algorithm = (String ) model .get ("algorithm" );
175+ assertEquals (FunctionName .KMEANS .name (), algorithm );
176+ });
177+ } catch (IOException e ) {
178+ assertNull (e );
179+ }
180+ try {
181+ // predict with trained model
182+ predict (mlFullAccessClient , FunctionName .KMEANS , modelId , irisIndex , kMeansParams , searchSourceBuilder , predictResult -> {
183+ String predictStatus = (String ) predictResult .get ("status" );
184+ assertEquals (MLTaskState .COMPLETED .name (), predictStatus );
185+ Map <String , Object > predictionResult = (Map <String , Object >) predictResult .get ("prediction_result" );
186+ ArrayList rows = (ArrayList ) predictionResult .get ("rows" );
187+ assertTrue (rows .size () > 1 );
188+ });
189+ } catch (IOException e ) {
190+ assertNull (e );
191+ }
192+ }, false );
193+ }
194+
195+ public void testTrainModelInAsyncWayWithFullAccess () throws IOException {
196+ train (mlFullAccessClient , FunctionName .KMEANS , irisIndex , KMeansParams .builder ().build (), searchSourceBuilder , trainResult -> {
197+ assertFalse (trainResult .containsKey ("model_id" ));
198+ String taskId = (String ) trainResult .get ("task_id" );
199+ assertNotNull (taskId );
200+ String status = (String ) trainResult .get ("status" );
201+ assertEquals (MLTaskState .CREATED .name (), status );
202+ try {
203+ getTask (mlFullAccessClient , taskId , task -> {
204+ String algorithm = (String ) task .get ("function_name" );
205+ assertEquals (FunctionName .KMEANS .name (), algorithm );
206+ });
207+ } catch (IOException e ) {
208+ assertNull (e );
209+ }
210+ }, true );
211+ }
212+
213+ public void testReadOnlyUser_CanGetModel_CanNotDeleteModel () throws IOException {
214+ KMeansParams kMeansParams = KMeansParams .builder ().build ();
215+ // train model with full access client
216+ train (mlFullAccessClient , FunctionName .KMEANS , irisIndex , kMeansParams , searchSourceBuilder , trainResult -> {
217+ String modelId = (String ) trainResult .get ("model_id" );
218+ assertNotNull (modelId );
219+ String status = (String ) trainResult .get ("status" );
220+ assertEquals (MLTaskState .COMPLETED .name (), status );
221+ try {
222+ // get model with readonly client
223+ getModel (mlReadOnlyClient , modelId , model -> {
224+ String algorithm = (String ) model .get ("algorithm" );
225+ assertEquals (FunctionName .KMEANS .name (), algorithm );
226+ });
227+ } catch (IOException e ) {
228+ assertNull (e );
229+ }
230+ try {
231+ // Failed to delete model with read only client
232+ deleteModel (mlReadOnlyClient , modelId , null );
233+ throw new RuntimeException ("Delete model for readonly user does not fail" );
234+ } catch (Exception e ) {
235+ assertEquals (ResponseException .class , e .getClass ());
236+ assertTrue (Throwables .getStackTraceAsString (e ).contains ("no permissions for [cluster:admin/opensearch/ml/models/delete]" ));
237+ }
238+ }, false );
239+ }
240+
241+ public void testReadOnlyUser_CanGetTask_CanNotDeleteTask () throws IOException {
242+ KMeansParams kMeansParams = KMeansParams .builder ().build ();
243+ // train model with full access client
244+ train (mlFullAccessClient , FunctionName .KMEANS , irisIndex , kMeansParams , searchSourceBuilder , trainResult -> {
245+ assertFalse (trainResult .containsKey ("model_id" ));
246+ String taskId = (String ) trainResult .get ("task_id" );
247+ assertNotNull (taskId );
248+ String status = (String ) trainResult .get ("status" );
249+ assertEquals (MLTaskState .CREATED .name (), status );
250+ try {
251+ // get task with readonly client
252+ getTask (mlReadOnlyClient , taskId , task -> {
253+ String algorithm = (String ) task .get ("function_name" );
254+ assertEquals (FunctionName .KMEANS .name (), algorithm );
255+ });
256+ } catch (IOException e ) {
257+ assertNull (e );
258+ }
259+ try {
260+ // Failed to delete task with read only client
261+ deleteTask (mlReadOnlyClient , taskId , null );
262+ throw new RuntimeException ("Delete task for readonly user does not fail" );
263+ } catch (Exception e ) {
264+ assertEquals (ResponseException .class , e .getClass ());
265+ assertTrue (Throwables .getStackTraceAsString (e ).contains ("no permissions for [cluster:admin/opensearch/ml/tasks/delete]" ));
266+ }
267+ }, true );
268+ }
269+
270+ public void testReadOnlyUser_CanSearchModels () throws IOException {
271+ KMeansParams kMeansParams = KMeansParams .builder ().build ();
272+ // train model with full access client
273+ train (mlFullAccessClient , FunctionName .KMEANS , irisIndex , kMeansParams , searchSourceBuilder , trainResult -> {
274+ String modelId = (String ) trainResult .get ("model_id" );
275+ assertNotNull (modelId );
276+ String status = (String ) trainResult .get ("status" );
277+ assertEquals (MLTaskState .COMPLETED .name (), status );
278+ try {
279+ // search model with readonly client
280+ searchModelsWithAlgoName (mlReadOnlyClient , FunctionName .KMEANS .name (), models -> {
281+ ArrayList <Object > hits = (ArrayList ) ((Map <String , Object >) models .get ("hits" )).get ("hits" );
282+ assertTrue (hits .size () > 0 );
283+ });
284+ } catch (IOException e ) {
285+ assertNull (e );
286+ }
287+ }, false );
288+ }
289+
290+ public void testReadOnlyUser_CanSearchTasks () throws IOException {
291+ KMeansParams kMeansParams = KMeansParams .builder ().build ();
292+ // train model with full access client
293+ train (mlFullAccessClient , FunctionName .KMEANS , irisIndex , kMeansParams , searchSourceBuilder , trainResult -> {
294+ assertFalse (trainResult .containsKey ("model_id" ));
295+ String taskId = (String ) trainResult .get ("task_id" );
296+ assertNotNull (taskId );
297+ String status = (String ) trainResult .get ("status" );
298+ assertEquals (MLTaskState .CREATED .name (), status );
299+ try {
300+ // search tasks with readonly client
301+ searchTasksWithAlgoName (mlReadOnlyClient , FunctionName .KMEANS .name (), tasks -> {
302+ ArrayList <Object > hits = (ArrayList ) ((Map <String , Object >) tasks .get ("hits" )).get ("hits" );
303+ assertTrue (hits .size () > 0 );
304+ });
305+ } catch (IOException e ) {
306+ assertNull (e );
307+ }
308+ }, true );
309+ }
145310}
0 commit comments