66package org .opensearch .ml .rest ;
77
88import java .io .IOException ;
9+ import java .text .ParseException ;
910import java .util .List ;
1011import java .util .Map ;
1112
13+ import org .junit .After ;
14+ import org .junit .Before ;
1215import org .opensearch .client .Response ;
1316import org .opensearch .ml .utils .TestHelper ;
1417
1518public class RestMLFlowAgentIT extends MLCommonsRestTestCase {
1619
20+ private String irisIndex = "iris_data" ;
21+
22+ @ Before
23+ public void setup () throws IOException , ParseException {
24+ ingestIrisData (irisIndex );
25+ }
26+
27+ @ After
28+ public void deleteIndices () throws IOException {
29+ deleteIndexWithAdminClient (irisIndex );
30+ }
31+
1732 public void testAgentCatIndexTool () throws IOException {
1833 // Register agent with CatIndexTool.
1934 Response response = registerAgentWithCatIndexTool ();
@@ -35,6 +50,27 @@ public void testAgentCatIndexTool() throws IOException {
3550 assertTrue (result .contains (".plugins-ml-agent" ));
3651 }
3752
53+ public void testAgentSearchIndexTool () throws IOException {
54+ // Register agent with SearchIndexTool.
55+ Response response = registerAgentWithSearchIndexTool ();
56+ Map responseMap = parseResponseToMap (response );
57+ String agentId = (String ) responseMap .get ("agent_id" );
58+ assertNotNull (agentId );
59+ assertEquals (20 , agentId .length ());
60+
61+ // Execute agent.
62+ response = executeAgentSearchIndexTool (agentId );
63+ responseMap = parseResponseToMap (response );
64+ List responseList = (List ) responseMap .get ("inference_results" );
65+ responseMap = (Map ) responseList .get (0 );
66+ responseList = (List ) responseMap .get ("output" );
67+ responseMap = (Map ) responseList .get (0 );
68+ assertEquals ("response" , responseMap .get ("name" ));
69+ String result = (String ) responseMap .get ("result" );
70+ assertNotNull (result );
71+ assertTrue (result .contains ("\" _source\" :{\" petal_length_in_cm\" " ));
72+ }
73+
3874 public static Response registerAgentWithCatIndexTool () throws IOException {
3975 String registerAgentEntity = "{\n "
4076 + " \" name\" : \" Test_Agent_For_CatIndex_tool\" ,\n "
@@ -54,20 +90,50 @@ public static Response registerAgentWithCatIndexTool() throws IOException {
5490 .makeRequest (client (), "POST" , "/_plugins/_ml/agents/_register" , null , TestHelper .toHttpEntity (registerAgentEntity ), null );
5591 }
5692
93+ public static Response registerAgentWithSearchIndexTool () throws IOException {
94+ String registerAgentEntity = "{\n "
95+ + " \" name\" : \" Test_Agent_For_SearchIndex_tool\" ,\n "
96+ + " \" type\" : \" flow\" ,\n "
97+ + " \" description\" : \" this is a test agent for the SearchIndexTool\" ,\n "
98+ + " \" tools\" : [\n "
99+ + " {\n "
100+ + " \" type\" : \" SearchIndexTool\" "
101+ + " }\n "
102+ + " ]\n "
103+ + "}" ;
104+ return TestHelper
105+ .makeRequest (client (), "POST" , "/_plugins/_ml/agents/_register" , null , TestHelper .toHttpEntity (registerAgentEntity ), null );
106+ }
107+
57108 public static Response executeAgentCatIndexTool (String agentId ) throws IOException {
58- String question = "How many indices do I have?" ;
59- return executeAgent (agentId , question );
109+ String question = "\" How many indices do I have?\" " ;
110+ return executeAgent (agentId , Map . of ( " question" , question ) );
60111 }
61112
62- public static Response executeAgent (String agentId , String question ) throws IOException {
63- String executeAgentEntity = "{\n " + " \" parameters\" : {\n " + " \" question\" : \" " + question + " \" \n " + " }\n " + "}" ;
113+ public static Response executeAgentSearchIndexTool (String agentId ) throws IOException {
114+ String input = "{\" index\" : \" iris_data\" , \" query\" : {\" size\" : 2, \" _source\" : \" petal_length_in_cm\" }}" ;
115+ return executeAgent (agentId , Map .of ("input" , input ));
116+ }
117+
118+ public static Response executeAgent (String agentId , Map <String , String > args ) throws IOException {
119+ if (args == null || args .isEmpty ()) {
120+ return null ;
121+ }
122+
123+ // Construct parameters.
124+ StringBuilder entityBuilder = new StringBuilder ("{\" parameters\" :{" );
125+ for (Map .Entry entry : args .entrySet ()) {
126+ entityBuilder .append ('"' ).append (entry .getKey ()).append ("\" :" ).append (entry .getValue ()).append (',' );
127+ }
128+ entityBuilder .replace (entityBuilder .length () - 1 , entityBuilder .length (), "}}" );
129+
64130 return TestHelper
65131 .makeRequest (
66132 client (),
67133 "POST" ,
68134 String .format ("/_plugins/_ml/agents/%s/_execute" , agentId ),
69135 null ,
70- TestHelper .toHttpEntity (executeAgentEntity ),
136+ TestHelper .toHttpEntity (entityBuilder . toString () ),
71137 null
72138 );
73139 }
0 commit comments