@@ -884,7 +884,7 @@ public void deleteTask() {
884884 }).when (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), any (), any ());
885885
886886 ArgumentCaptor <DeleteResponse > argumentCaptor = ArgumentCaptor .forClass (DeleteResponse .class );
887- machineLearningNodeClient .deleteTask (taskId , deleteTaskActionListener );
887+ machineLearningNodeClient .deleteTask (taskId , null , deleteTaskActionListener );
888888
889889 verify (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), isA (MLTaskDeleteRequest .class ), any ());
890890 verify (deleteTaskActionListener ).onResponse (argumentCaptor .capture ());
@@ -1276,6 +1276,185 @@ public void getConfigRejectedMasterKey() {
12761276 assertEquals ("You are not allowed to access this config doc" , argumentCaptor .getValue ().getLocalizedMessage ());
12771277 }
12781278
1279+ @ Test
1280+ public void predict_withTenantId () {
1281+ String tenantId = "testTenant" ;
1282+ doAnswer (invocation -> {
1283+ ActionListener <MLTaskResponse > actionListener = invocation .getArgument (2 );
1284+ MLPredictionOutput predictionOutput = MLPredictionOutput
1285+ .builder ()
1286+ .status ("Success" )
1287+ .predictionResult (output )
1288+ .taskId ("taskId" )
1289+ .build ();
1290+ actionListener .onResponse (MLTaskResponse .builder ().output (predictionOutput ).build ());
1291+ return null ;
1292+ }).when (client ).execute (eq (MLPredictionTaskAction .INSTANCE ), any (), any ());
1293+
1294+ ArgumentCaptor <MLPredictionTaskRequest > requestCaptor = ArgumentCaptor .forClass (MLPredictionTaskRequest .class );
1295+ MLInput mlInput = MLInput .builder ().algorithm (FunctionName .KMEANS ).inputDataset (input ).build ();
1296+ machineLearningNodeClient .predict ("modelId" , tenantId , mlInput , dataFrameActionListener );
1297+
1298+ verify (client ).execute (eq (MLPredictionTaskAction .INSTANCE ), requestCaptor .capture (), any ());
1299+ assertEquals (tenantId , requestCaptor .getValue ().getTenantId ());
1300+ assertEquals ("modelId" , requestCaptor .getValue ().getModelId ());
1301+ }
1302+
1303+ @ Test
1304+ public void getTask_withFailure () {
1305+ String taskId = "taskId" ;
1306+ String errorMessage = "Task not found" ;
1307+
1308+ doAnswer (invocation -> {
1309+ ActionListener <MLTaskGetResponse > actionListener = invocation .getArgument (2 );
1310+ actionListener .onFailure (new IllegalArgumentException (errorMessage ));
1311+ return null ;
1312+ }).when (client ).execute (eq (MLTaskGetAction .INSTANCE ), any (), any ());
1313+
1314+ ArgumentCaptor <Exception > exceptionCaptor = ArgumentCaptor .forClass (Exception .class );
1315+
1316+ machineLearningNodeClient .getTask (taskId , new ActionListener <>() {
1317+ @ Override
1318+ public void onResponse (MLTask mlTask ) {
1319+ fail ("Expected failure but got success" );
1320+ }
1321+
1322+ @ Override
1323+ public void onFailure (Exception e ) {
1324+ assertEquals (errorMessage , e .getMessage ());
1325+ }
1326+ });
1327+
1328+ verify (client ).execute (eq (MLTaskGetAction .INSTANCE ), isA (MLTaskGetRequest .class ), any ());
1329+ }
1330+
1331+ @ Test
1332+ public void deploy_withTenantId () {
1333+ String modelId = "testModel" ;
1334+ String tenantId = "testTenant" ;
1335+ String taskId = "taskId" ;
1336+ String status = MLTaskState .CREATED .name ();
1337+
1338+ doAnswer (invocation -> {
1339+ ActionListener <MLDeployModelResponse > actionListener = invocation .getArgument (2 );
1340+ MLDeployModelResponse output = new MLDeployModelResponse (taskId , MLTaskType .DEPLOY_MODEL , status );
1341+ actionListener .onResponse (output );
1342+ return null ;
1343+ }).when (client ).execute (eq (MLDeployModelAction .INSTANCE ), any (), any ());
1344+
1345+ ArgumentCaptor <MLDeployModelRequest > requestCaptor = ArgumentCaptor .forClass (MLDeployModelRequest .class );
1346+ machineLearningNodeClient .deploy (modelId , tenantId , deployModelActionListener );
1347+
1348+ verify (client ).execute (eq (MLDeployModelAction .INSTANCE ), requestCaptor .capture (), any ());
1349+ assertEquals (modelId , requestCaptor .getValue ().getModelId ());
1350+ assertEquals (tenantId , requestCaptor .getValue ().getTenantId ());
1351+ }
1352+
1353+ @ Test
1354+ public void trainAndPredict_withNullInput () {
1355+ exceptionRule .expect (IllegalArgumentException .class );
1356+ exceptionRule .expectMessage ("ML Input can't be null" );
1357+
1358+ machineLearningNodeClient .trainAndPredict (null , trainingActionListener );
1359+ }
1360+
1361+ @ Test
1362+ public void trainAndPredict_withNullDataSet () {
1363+ exceptionRule .expect (IllegalArgumentException .class );
1364+ exceptionRule .expectMessage ("input data set can't be null" );
1365+
1366+ MLInput mlInput = MLInput .builder ().algorithm (FunctionName .KMEANS ).build ();
1367+ machineLearningNodeClient .trainAndPredict (mlInput , trainingActionListener );
1368+ }
1369+
1370+ @ Test
1371+ public void getTask_withTaskIdAndTenantId () {
1372+ String taskId = "taskId" ;
1373+ String tenantId = "testTenant" ;
1374+ String modelId = "modelId" ;
1375+
1376+ doAnswer (invocation -> {
1377+ ActionListener <MLTaskGetResponse > actionListener = invocation .getArgument (2 );
1378+ MLTask mlTask = MLTask .builder ().taskId (taskId ).modelId (modelId ).functionName (FunctionName .KMEANS ).build ();
1379+ MLTaskGetResponse output = MLTaskGetResponse .builder ().mlTask (mlTask ).build ();
1380+ actionListener .onResponse (output );
1381+ return null ;
1382+ }).when (client ).execute (eq (MLTaskGetAction .INSTANCE ), any (), any ());
1383+
1384+ ArgumentCaptor <MLTaskGetRequest > requestCaptor = ArgumentCaptor .forClass (MLTaskGetRequest .class );
1385+ ArgumentCaptor <MLTask > taskCaptor = ArgumentCaptor .forClass (MLTask .class );
1386+
1387+ machineLearningNodeClient .getTask (taskId , tenantId , getTaskActionListener );
1388+
1389+ verify (client ).execute (eq (MLTaskGetAction .INSTANCE ), requestCaptor .capture (), any ());
1390+ verify (getTaskActionListener ).onResponse (taskCaptor .capture ());
1391+
1392+ // Verify request parameters
1393+ assertEquals (taskId , requestCaptor .getValue ().getTaskId ());
1394+ assertEquals (tenantId , requestCaptor .getValue ().getTenantId ());
1395+
1396+ // Verify response
1397+ assertEquals (taskId , taskCaptor .getValue ().getTaskId ());
1398+ assertEquals (modelId , taskCaptor .getValue ().getModelId ());
1399+ assertEquals (FunctionName .KMEANS , taskCaptor .getValue ().getFunctionName ());
1400+ }
1401+
1402+ @ Test
1403+ public void deleteTask_withTaskId () {
1404+ String taskId = "taskId" ;
1405+
1406+ doAnswer (invocation -> {
1407+ ActionListener <DeleteResponse > actionListener = invocation .getArgument (2 );
1408+ ShardId shardId = new ShardId (new Index ("indexName" , "uuid" ), 1 );
1409+ DeleteResponse output = new DeleteResponse (shardId , taskId , 1 , 1 , 1 , true );
1410+ actionListener .onResponse (output );
1411+ return null ;
1412+ }).when (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), any (), any ());
1413+
1414+ ArgumentCaptor <MLTaskDeleteRequest > requestCaptor = ArgumentCaptor .forClass (MLTaskDeleteRequest .class );
1415+ ArgumentCaptor <DeleteResponse > responseCaptor = ArgumentCaptor .forClass (DeleteResponse .class );
1416+
1417+ machineLearningNodeClient .deleteTask (taskId , deleteTaskActionListener );
1418+
1419+ verify (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), requestCaptor .capture (), any ());
1420+ verify (deleteTaskActionListener ).onResponse (responseCaptor .capture ());
1421+
1422+ // Verify request parameter
1423+ assertEquals (taskId , requestCaptor .getValue ().getTaskId ());
1424+
1425+ // Verify response
1426+ assertEquals (taskId , responseCaptor .getValue ().getId ());
1427+ assertEquals ("DELETED" , responseCaptor .getValue ().getResult ().toString ());
1428+ }
1429+
1430+ @ Test
1431+ public void deleteTask_withFailure () {
1432+ String taskId = "taskId" ;
1433+ String errorMessage = "Task deletion failed" ;
1434+
1435+ doAnswer (invocation -> {
1436+ ActionListener <DeleteResponse > actionListener = invocation .getArgument (2 );
1437+ actionListener .onFailure (new RuntimeException (errorMessage ));
1438+ return null ;
1439+ }).when (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), any (), any ());
1440+
1441+ ArgumentCaptor <Exception > exceptionCaptor = ArgumentCaptor .forClass (Exception .class );
1442+
1443+ machineLearningNodeClient .deleteTask (taskId , new ActionListener <>() {
1444+ @ Override
1445+ public void onResponse (DeleteResponse deleteResponse ) {
1446+ fail ("Expected failure but got success" );
1447+ }
1448+
1449+ @ Override
1450+ public void onFailure (Exception e ) {
1451+ assertEquals (errorMessage , e .getMessage ());
1452+ }
1453+ });
1454+
1455+ verify (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), isA (MLTaskDeleteRequest .class ), any ());
1456+ }
1457+
12791458 private SearchResponse createSearchResponse (ToXContentObject o ) throws IOException {
12801459 XContentBuilder content = o .toXContent (XContentFactory .jsonBuilder (), ToXContent .EMPTY_PARAMS );
12811460
0 commit comments