88import static org .junit .Assert .assertEquals ;
99import static org .junit .Assert .assertFalse ;
1010import static org .junit .Assert .assertTrue ;
11+ import static org .junit .Assert .fail ;
1112import static org .mockito .Answers .RETURNS_DEEP_STUBS ;
1213import static org .mockito .ArgumentMatchers .any ;
1314import static org .mockito .ArgumentMatchers .eq ;
@@ -325,6 +326,64 @@ public void train() {
325326 assertEquals (status , ((MLTrainingOutput ) argumentCaptor .getValue ()).getStatus ());
326327 }
327328
329+ @ Test
330+ public void getModel_withTenantId () {
331+ String modelContent = "test content" ;
332+ String tenantId = "tenantId" ;
333+ doAnswer (invocation -> {
334+ ActionListener <MLModelGetResponse > actionListener = invocation .getArgument (2 );
335+ MLModel mlModel = MLModel .builder ().algorithm (FunctionName .KMEANS ).name ("test" ).content (modelContent ).build ();
336+ MLModelGetResponse output = MLModelGetResponse .builder ().mlModel (mlModel ).build ();
337+ actionListener .onResponse (output );
338+ return null ;
339+ }).when (client ).execute (eq (MLModelGetAction .INSTANCE ), any (), any ());
340+
341+ ArgumentCaptor <MLModel > argumentCaptor = ArgumentCaptor .forClass (MLModel .class );
342+ machineLearningNodeClient .getModel ("modelId" , tenantId , getModelActionListener );
343+
344+ verify (client ).execute (eq (MLModelGetAction .INSTANCE ), isA (MLModelGetRequest .class ), any ());
345+ verify (getModelActionListener ).onResponse (argumentCaptor .capture ());
346+ assertEquals (FunctionName .KMEANS , argumentCaptor .getValue ().getAlgorithm ());
347+ assertEquals (modelContent , argumentCaptor .getValue ().getContent ());
348+ }
349+
350+ @ Test
351+ public void undeployModels_withNullNodeIds () {
352+ doAnswer (invocation -> {
353+ ActionListener <MLUndeployModelsResponse > actionListener = invocation .getArgument (2 );
354+ MLUndeployModelsResponse output = new MLUndeployModelsResponse (
355+ new MLUndeployModelNodesResponse (ClusterName .DEFAULT , Collections .emptyList (), Collections .emptyList ())
356+ );
357+ actionListener .onResponse (output );
358+ return null ;
359+ }).when (client ).execute (eq (MLUndeployModelsAction .INSTANCE ), any (), any ());
360+
361+ machineLearningNodeClient .undeploy (new String [] { "model1" }, null , undeployModelsActionListener );
362+ verify (client ).execute (eq (MLUndeployModelsAction .INSTANCE ), isA (MLUndeployModelsRequest .class ), any ());
363+ }
364+
365+ @ Test
366+ public void createConnector_withValidInput () {
367+ doAnswer (invocation -> {
368+ ActionListener <MLCreateConnectorResponse > actionListener = invocation .getArgument (2 );
369+ MLCreateConnectorResponse output = new MLCreateConnectorResponse ("connectorId" );
370+ actionListener .onResponse (output );
371+ return null ;
372+ }).when (client ).execute (eq (MLCreateConnectorAction .INSTANCE ), any (), any ());
373+
374+ MLCreateConnectorInput input = MLCreateConnectorInput
375+ .builder ()
376+ .name ("testConnector" )
377+ .protocol ("http" )
378+ .version ("1" )
379+ .credential (Map .of ("TEST_CREDENTIAL_KEY" , "TEST_CREDENTIAL_VALUE" ))
380+ .parameters (Map .of ("endpoint" , "https://example.com" ))
381+ .build ();
382+
383+ machineLearningNodeClient .createConnector (input , createConnectorActionListener );
384+ verify (client ).execute (eq (MLCreateConnectorAction .INSTANCE ), isA (MLCreateConnectorRequest .class ), any ());
385+ }
386+
328387 @ Test
329388 public void registerModelGroup_withValidInput () {
330389 doAnswer (invocation -> {
@@ -346,6 +405,146 @@ public void registerModelGroup_withValidInput() {
346405 verify (client ).execute (eq (MLRegisterModelGroupAction .INSTANCE ), isA (MLRegisterModelGroupRequest .class ), any ());
347406 }
348407
408+ @ Test
409+ public void listTools_withValidRequest () {
410+ doAnswer (invocation -> {
411+ ActionListener <MLToolsListResponse > actionListener = invocation .getArgument (2 );
412+ MLToolsListResponse output = MLToolsListResponse
413+ .builder ()
414+ .toolMetadata (
415+ Arrays
416+ .asList (
417+ ToolMetadata .builder ().name ("tool1" ).description ("description1" ).build (),
418+ ToolMetadata .builder ().name ("tool2" ).description ("description2" ).build ()
419+ )
420+ )
421+ .build ();
422+ actionListener .onResponse (output );
423+ return null ;
424+ }).when (client ).execute (eq (MLListToolsAction .INSTANCE ), any (), any ());
425+
426+ machineLearningNodeClient .listTools (listToolsActionListener );
427+ verify (client ).execute (eq (MLListToolsAction .INSTANCE ), isA (MLToolsListRequest .class ), any ());
428+ }
429+
430+ @ Test
431+ public void listTools_withEmptyResponse () {
432+ doAnswer (invocation -> {
433+ ActionListener <MLToolsListResponse > actionListener = invocation .getArgument (2 );
434+ MLToolsListResponse output = MLToolsListResponse .builder ().toolMetadata (Collections .emptyList ()).build ();
435+ actionListener .onResponse (output );
436+ return null ;
437+ }).when (client ).execute (eq (MLListToolsAction .INSTANCE ), any (), any ());
438+
439+ ArgumentCaptor <List <ToolMetadata >> argumentCaptor = ArgumentCaptor .forClass (List .class );
440+ machineLearningNodeClient .listTools (listToolsActionListener );
441+
442+ verify (client ).execute (eq (MLListToolsAction .INSTANCE ), isA (MLToolsListRequest .class ), any ());
443+ verify (listToolsActionListener ).onResponse (argumentCaptor .capture ());
444+
445+ List <ToolMetadata > capturedTools = argumentCaptor .getValue ();
446+ assertTrue (capturedTools .isEmpty ());
447+ }
448+
449+ @ Test
450+ public void getTool_withValidToolName () {
451+ doAnswer (invocation -> {
452+ ActionListener <MLToolGetResponse > actionListener = invocation .getArgument (2 );
453+ MLToolGetResponse output = MLToolGetResponse
454+ .builder ()
455+ .toolMetadata (ToolMetadata .builder ().name ("tool1" ).description ("description1" ).build ())
456+ .build ();
457+ actionListener .onResponse (output );
458+ return null ;
459+ }).when (client ).execute (eq (MLGetToolAction .INSTANCE ), any (), any ());
460+
461+ machineLearningNodeClient .getTool ("tool1" , getToolActionListener );
462+ verify (client ).execute (eq (MLGetToolAction .INSTANCE ), isA (MLToolGetRequest .class ), any ());
463+ }
464+
465+ @ Test
466+ public void getTool_withValidRequest () {
467+ ToolMetadata toolMetadata = ToolMetadata
468+ .builder ()
469+ .name ("MathTool" )
470+ .description ("Use this tool to calculate any math problem." )
471+ .build ();
472+
473+ doAnswer (invocation -> {
474+ ActionListener <MLToolGetResponse > actionListener = invocation .getArgument (2 );
475+ MLToolGetResponse output = MLToolGetResponse .builder ().toolMetadata (toolMetadata ).build ();
476+ actionListener .onResponse (output );
477+ return null ;
478+ }).when (client ).execute (eq (MLGetToolAction .INSTANCE ), any (), any ());
479+
480+ ArgumentCaptor <ToolMetadata > argumentCaptor = ArgumentCaptor .forClass (ToolMetadata .class );
481+ machineLearningNodeClient .getTool ("MathTool" , getToolActionListener );
482+
483+ verify (client ).execute (eq (MLGetToolAction .INSTANCE ), isA (MLToolGetRequest .class ), any ());
484+ verify (getToolActionListener ).onResponse (argumentCaptor .capture ());
485+
486+ ToolMetadata capturedTool = argumentCaptor .getValue ();
487+ assertEquals ("MathTool" , capturedTool .getName ());
488+ assertEquals ("Use this tool to calculate any math problem." , capturedTool .getDescription ());
489+ }
490+
491+ @ Test
492+ public void getTool_withFailureResponse () {
493+ doAnswer (invocation -> {
494+ ActionListener <MLToolGetResponse > actionListener = invocation .getArgument (2 );
495+ actionListener .onFailure (new RuntimeException ("Test exception" ));
496+ return null ;
497+ }).when (client ).execute (eq (MLGetToolAction .INSTANCE ), any (), any ());
498+
499+ machineLearningNodeClient .getTool ("MathTool" , new ActionListener <>() {
500+ @ Override
501+ public void onResponse (ToolMetadata toolMetadata ) {
502+ fail ("Expected failure but got response" );
503+ }
504+
505+ @ Override
506+ public void onFailure (Exception e ) {
507+ assertEquals ("Test exception" , e .getMessage ());
508+ }
509+ });
510+
511+ verify (client ).execute (eq (MLGetToolAction .INSTANCE ), isA (MLToolGetRequest .class ), any ());
512+ }
513+
514+ @ Test
515+ public void train_withAsync () {
516+ doAnswer (invocation -> {
517+ ActionListener <MLTaskResponse > actionListener = invocation .getArgument (2 );
518+ MLTrainingOutput output = MLTrainingOutput .builder ().status ("InProgress" ).modelId ("modelId" ).build ();
519+ actionListener .onResponse (MLTaskResponse .builder ().output (output ).build ());
520+ return null ;
521+ }).when (client ).execute (eq (MLTrainingTaskAction .INSTANCE ), any (), any ());
522+
523+ MLInput mlInput = MLInput .builder ().algorithm (FunctionName .KMEANS ).inputDataset (input ).build ();
524+ machineLearningNodeClient .train (mlInput , true , trainingActionListener );
525+ verify (client ).execute (eq (MLTrainingTaskAction .INSTANCE ), isA (MLTrainingTaskRequest .class ), any ());
526+ }
527+
528+ @ Test
529+ public void deleteModel_withTenantId () {
530+ String modelId = "testModelId" ;
531+ String tenantId = "tenantId" ;
532+ doAnswer (invocation -> {
533+ ActionListener <DeleteResponse > actionListener = invocation .getArgument (2 );
534+ ShardId shardId = new ShardId (new Index ("indexName" , "uuid" ), 1 );
535+ DeleteResponse output = new DeleteResponse (shardId , modelId , 1 , 1 , 1 , true );
536+ actionListener .onResponse (output );
537+ return null ;
538+ }).when (client ).execute (eq (MLModelDeleteAction .INSTANCE ), any (), any ());
539+
540+ ArgumentCaptor <DeleteResponse > argumentCaptor = ArgumentCaptor .forClass (DeleteResponse .class );
541+ machineLearningNodeClient .deleteModel (modelId , tenantId , deleteModelActionListener );
542+
543+ verify (client ).execute (eq (MLModelDeleteAction .INSTANCE ), isA (MLModelDeleteRequest .class ), any ());
544+ verify (deleteModelActionListener ).onResponse (argumentCaptor .capture ());
545+ assertEquals (modelId , argumentCaptor .getValue ().getId ());
546+ }
547+
349548 @ Test
350549 public void train_Exception_WithNullDataSet () {
351550 exceptionRule .expect (IllegalArgumentException .class );
0 commit comments