@@ -237,6 +237,70 @@ public void testPredictRemoteModel() throws IOException, InterruptedException {
237237 assertFalse (((String ) responseMap .get ("text" )).isEmpty ());
238238 }
239239
240+ public void testPredictRemoteModelWithInterface (String testCase , Consumer <Map > verifyResponse , Consumer <Exception > verifyException )
241+ throws IOException ,
242+ InterruptedException {
243+ // Skip test if key is null
244+ if (OPENAI_KEY == null ) {
245+ return ;
246+ }
247+ Response response = createConnector (completionModelConnectorEntity );
248+ Map responseMap = parseResponseToMap (response );
249+ String connectorId = (String ) responseMap .get ("connector_id" );
250+ response = registerRemoteModelWithInterface ("openAI-GPT-3.5 completions" , connectorId , testCase );
251+ responseMap = parseResponseToMap (response );
252+ String taskId = (String ) responseMap .get ("task_id" );
253+ waitForTask (taskId , MLTaskState .COMPLETED );
254+ response = getTask (taskId );
255+ responseMap = parseResponseToMap (response );
256+ String modelId = (String ) responseMap .get ("model_id" );
257+ response = deployRemoteModel (modelId );
258+ responseMap = parseResponseToMap (response );
259+ taskId = (String ) responseMap .get ("task_id" );
260+ waitForTask (taskId , MLTaskState .COMPLETED );
261+ String predictInput = "{\n " + " \" parameters\" : {\n " + " \" prompt\" : \" Say this is a test\" \n " + " }\n " + "}" ;
262+ try {
263+ response = predictRemoteModel (modelId , predictInput );
264+ responseMap = parseResponseToMap (response );
265+ verifyResponse .accept (responseMap );
266+ } catch (Exception e ) {
267+ verifyException .accept (e );
268+ }
269+ }
270+
271+ public void testPredictRemoteModelWithCorrectInterface () throws IOException , InterruptedException {
272+ testPredictRemoteModelWithInterface ("correctInterface" , (responseMap ) -> {
273+ List responseList = (List ) responseMap .get ("inference_results" );
274+ responseMap = (Map ) responseList .get (0 );
275+ responseList = (List ) responseMap .get ("output" );
276+ responseMap = (Map ) responseList .get (0 );
277+ responseMap = (Map ) responseMap .get ("dataAsMap" );
278+ responseList = (List ) responseMap .get ("choices" );
279+ if (responseList == null ) {
280+ assertTrue (checkThrottlingOpenAI (responseMap ));
281+ return ;
282+ }
283+ responseMap = (Map ) responseList .get (0 );
284+ assertFalse (((String ) responseMap .get ("text" )).isEmpty ());
285+ }, null );
286+ }
287+
288+ public void testPredictRemoteModelWithWrongInputInterface () throws IOException , InterruptedException {
289+ testPredictRemoteModelWithInterface ("wrongInputInterface" , null , (exception ) -> {
290+ assertTrue (exception instanceof org .opensearch .client .ResponseException );
291+ String stackTrace = ExceptionUtils .getStackTrace (exception );
292+ assertTrue (stackTrace .contains ("Error validating input schema" ));
293+ });
294+ }
295+
296+ public void testPredictRemoteModelWithWrongOutputInterface () throws IOException , InterruptedException {
297+ testPredictRemoteModelWithInterface ("wrongOutputInterface" , null , (exception ) -> {
298+ assertTrue (exception instanceof org .opensearch .client .ResponseException );
299+ String stackTrace = ExceptionUtils .getStackTrace (exception );
300+ assertTrue (stackTrace .contains ("Error validating output schema" ));
301+ });
302+ }
303+
240304 public void testUndeployRemoteModel () throws IOException , InterruptedException {
241305 Response response = createConnector (completionModelConnectorEntity );
242306 Map responseMap = parseResponseToMap (response );
@@ -777,6 +841,183 @@ public static Response registerRemoteModel(String name, String connectorId) thro
777841 .makeRequest (client (), "POST" , "/_plugins/_ml/models/_register" , null , TestHelper .toHttpEntity (registerModelEntity ), null );
778842 }
779843
844+ public static Response registerRemoteModelWithInterface (String name , String connectorId , String testCase ) throws IOException {
845+ String registerModelGroupEntity = "{\n "
846+ + " \" name\" : \" remote_model_group\" ,\n "
847+ + " \" description\" : \" This is an example description\" \n "
848+ + "}" ;
849+ Response response = TestHelper
850+ .makeRequest (
851+ client (),
852+ "POST" ,
853+ "/_plugins/_ml/model_groups/_register" ,
854+ null ,
855+ TestHelper .toHttpEntity (registerModelGroupEntity ),
856+ null
857+ );
858+ Map responseMap = parseResponseToMap (response );
859+ assertEquals ((String ) responseMap .get ("status" ), "CREATED" );
860+ String modelGroupId = (String ) responseMap .get ("model_group_id" );
861+
862+ final String openaiConnectorEntityWithCorrectInterface = "{\n "
863+ + " \" name\" : \" "
864+ + name
865+ + "\" ,\n "
866+ + " \" model_group_id\" : \" "
867+ + modelGroupId
868+ + "\" ,\n "
869+ + " \" function_name\" : \" remote\" ,\n "
870+ + " \" connector_id\" : \" "
871+ + connectorId
872+ + "\" ,\n "
873+ + " \" interface\" : {\n "
874+ + " \" input\" : {\n "
875+ + " \" properties\" : {\n "
876+ + " \" parameters\" : {\n "
877+ + " \" properties\" : {\n "
878+ + " \" prompt\" : {\n "
879+ + " \" type\" : \" string\" ,\n "
880+ + " \" description\" : \" This is a test description field\" \n "
881+ + " }\n "
882+ + " }\n "
883+ + " }\n "
884+ + " }\n "
885+ + " },\n "
886+ + " \" output\" : {\n "
887+ + " \" properties\" : {\n "
888+ + " \" inference_results\" : {\n "
889+ + " \" type\" : \" array\" ,\n "
890+ + " \" items\" : {\n "
891+ + " \" type\" : \" object\" ,\n "
892+ + " \" properties\" : {\n "
893+ + " \" output\" : {\n "
894+ + " \" type\" : \" array\" ,\n "
895+ + " \" items\" : {\n "
896+ + " \" properties\" : {\n "
897+ + " \" name\" : {\n "
898+ + " \" type\" : \" string\" ,\n "
899+ + " \" description\" : \" This is a test description field\" \n "
900+ + " },\n "
901+ + " \" dataAsMap\" : {\n "
902+ + " \" type\" : \" object\" ,\n "
903+ + " \" description\" : \" This is a test description field\" \n "
904+ + " }\n "
905+ + " }\n "
906+ + " },\n "
907+ + " \" description\" : \" This is a test description field\" \n "
908+ + " },\n "
909+ + " \" status_code\" : {\n "
910+ + " \" type\" : \" integer\" ,\n "
911+ + " \" description\" : \" This is a test description field\" \n "
912+ + " }\n "
913+ + " }\n "
914+ + " },\n "
915+ + " \" description\" : \" This is a test description field\" \n "
916+ + " }\n "
917+ + " }\n "
918+ + " }\n "
919+ + " }\n "
920+ + "}" ;
921+
922+ final String openaiConnectorEntityWithWrongInputInterface = "{\n "
923+ + " \" name\" : \" "
924+ + name
925+ + "\" ,\n "
926+ + " \" model_group_id\" : \" "
927+ + modelGroupId
928+ + "\" ,\n "
929+ + " \" function_name\" : \" remote\" ,\n "
930+ + " \" connector_id\" : \" "
931+ + connectorId
932+ + "\" ,\n "
933+ + " \" interface\" : {\n "
934+ + " \" input\" : {\n "
935+ + " \" properties\" : {\n "
936+ + " \" parameters\" : {\n "
937+ + " \" properties\" : {\n "
938+ + " \" prompt\" : {\n "
939+ + " \" type\" : \" integer\" ,\n "
940+ + " \" description\" : \" This is a test description field\" \n "
941+ + " }\n "
942+ + " }\n "
943+ + " }\n "
944+ + " }\n "
945+ + " }\n "
946+ + " }\n "
947+ + "}" ;
948+
949+ final String openaiConnectorEntityWithWrongOutputInterface = "{\n "
950+ + " \" name\" : \" "
951+ + name
952+ + "\" ,\n "
953+ + " \" model_group_id\" : \" "
954+ + modelGroupId
955+ + "\" ,\n "
956+ + " \" function_name\" : \" remote\" ,\n "
957+ + " \" connector_id\" : \" "
958+ + connectorId
959+ + "\" ,\n "
960+ + " \" interface\" : {\n "
961+ + " \" output\" : {\n "
962+ + " \" properties\" : {\n "
963+ + " \" inference_results\" : {\n "
964+ + " \" type\" : \" array\" ,\n "
965+ + " \" items\" : {\n "
966+ + " \" type\" : \" object\" ,\n "
967+ + " \" properties\" : {\n "
968+ + " \" output\" : {\n "
969+ + " \" type\" : \" integer\" ,\n "
970+ + " \" description\" : \" This is a test description field\" \n "
971+ + " },\n "
972+ + " \" status_code\" : {\n "
973+ + " \" type\" : \" integer\" ,\n "
974+ + " \" description\" : \" This is a test description field\" \n "
975+ + " }\n "
976+ + " }\n "
977+ + " },\n "
978+ + " \" description\" : \" This is a test description field\" \n "
979+ + " }\n "
980+ + " }\n "
981+ + " }\n "
982+ + " }\n "
983+ + "}" ;
984+
985+ switch (testCase ) {
986+ case "correctInterface" :
987+ return TestHelper
988+ .makeRequest (
989+ client (),
990+ "POST" ,
991+ "/_plugins/_ml/models/_register" ,
992+ null ,
993+ TestHelper .toHttpEntity (openaiConnectorEntityWithCorrectInterface ),
994+ null
995+ );
996+ case "wrongInputInterface" :
997+ return TestHelper
998+ .makeRequest (
999+ client (),
1000+ "POST" ,
1001+ "/_plugins/_ml/models/_register" ,
1002+ null ,
1003+ TestHelper .toHttpEntity (openaiConnectorEntityWithWrongInputInterface ),
1004+ null
1005+ );
1006+ case "wrongOutputInterface" :
1007+ return TestHelper
1008+ .makeRequest (
1009+ client (),
1010+ "POST" ,
1011+ "/_plugins/_ml/models/_register" ,
1012+ null ,
1013+ TestHelper .toHttpEntity (openaiConnectorEntityWithWrongOutputInterface ),
1014+ null
1015+ );
1016+ default :
1017+ throw new IllegalArgumentException ("Invalid test case" );
1018+ }
1019+ }
1020+
7801021 public static Response deployRemoteModel (String modelId ) throws IOException {
7811022 return TestHelper .makeRequest (client (), "POST" , "/_plugins/_ml/models/" + modelId + "/_deploy" , null , "" , null );
7821023 }
@@ -831,4 +1072,15 @@ public String registerRemoteModel() throws IOException {
8311072 logger .info ("task ID created: {}" , taskId );
8321073 return taskId ;
8331074 }
1075+
1076+ public String registerRemoteModelWithInterface (String testCase ) throws IOException {
1077+ Response response = createConnector (completionModelConnectorEntity );
1078+ Map responseMap = parseResponseToMap (response );
1079+ String connectorId = (String ) responseMap .get ("connector_id" );
1080+ response = registerRemoteModelWithInterface ("openAI-GPT-3.5 completions" , connectorId , testCase );
1081+ responseMap = parseResponseToMap (response );
1082+ String taskId = (String ) responseMap .get ("task_id" );
1083+ logger .info ("task ID created: {}" , taskId );
1084+ return taskId ;
1085+ }
8341086}
0 commit comments