@@ -740,120 +740,6 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti
740740 assertFalse (((String ) responseMap .get ("text" )).isEmpty ());
741741 }
742742
743- public void testCohereClassifyModel () throws IOException , InterruptedException {
744- // Skip test if key is null
745- if (COHERE_KEY == null ) {
746- return ;
747- }
748- String entity = "{\n "
749- + " \" name\" : \" Cohere classify model Connector\" ,\n "
750- + " \" description\" : \" The connector to public Cohere classify model service\" ,\n "
751- + " \" version\" : 1,\n "
752- + " \" client_config\" : {\n "
753- + " \" max_connection\" : 20,\n "
754- + " \" connection_timeout\" : 50000,\n "
755- + " \" read_timeout\" : 50000\n "
756- + " },\n "
757- + " \" protocol\" : \" http\" ,\n "
758- + " \" parameters\" : {\n "
759- + " \" endpoint\" : \" api.cohere.ai\" ,\n "
760- + " \" auth\" : \" API_Key\" ,\n "
761- + " \" content_type\" : \" application/json\" ,\n "
762- + " \" max_tokens\" : \" 20\" \n "
763- + " },\n "
764- + " \" credential\" : {\n "
765- + " \" cohere_key\" : \" "
766- + COHERE_KEY
767- + "\" \n "
768- + " },\n "
769- + " \" actions\" : [\n "
770- + " {\n "
771- + " \" action_type\" : \" predict\" ,\n "
772- + " \" method\" : \" POST\" ,\n "
773- + " \" url\" : \" https://${parameters.endpoint}/v1/classify\" ,\n "
774- + " \" headers\" : { \n "
775- + " \" Authorization\" : \" Bearer ${credential.cohere_key}\" \n "
776- + " },\n "
777- + " \" request_body\" : \" { \\ \" inputs\\ \" : ${parameters.inputs}, \\ \" examples\\ \" : ${parameters.examples}, \\ \" truncate\\ \" : \\ \" END\\ \" }\" \n "
778- + " }\n "
779- + " ]\n "
780- + "}" ;
781- Response response = createConnector (entity );
782- Map responseMap = parseResponseToMap (response );
783- String connectorId = (String ) responseMap .get ("connector_id" );
784- response = registerRemoteModel ("cohere classify model" , connectorId );
785- responseMap = parseResponseToMap (response );
786- String taskId = (String ) responseMap .get ("task_id" );
787- waitForTask (taskId , MLTaskState .COMPLETED );
788- response = getTask (taskId );
789- responseMap = parseResponseToMap (response );
790- String modelId = (String ) responseMap .get ("model_id" );
791- response = deployRemoteModel (modelId );
792- responseMap = parseResponseToMap (response );
793- taskId = (String ) responseMap .get ("task_id" );
794- waitForTask (taskId , MLTaskState .COMPLETED );
795- String predictInput = "{\n "
796- + " \" parameters\" : {\n "
797- + " \" inputs\" : [\n "
798- + " \" Confirm your email address\" ,\n "
799- + " \" hey i need u to send some $\" \n "
800- + " ],\n "
801- + " \" examples\" : [\n "
802- + " {\n "
803- + " \" text\" : \" Dermatologists don't like her!\" ,\n "
804- + " \" label\" : \" Spam\" \n "
805- + " },\n "
806- + " {\n "
807- + " \" text\" : \" Hello, open to this?\" ,\n "
808- + " \" label\" : \" Spam\" \n "
809- + " },\n "
810- + " {\n "
811- + " \" text\" : \" I need help please wire me $1000 right now\" ,\n "
812- + " \" label\" : \" Spam\" \n "
813- + " },\n "
814- + " {\n "
815- + " \" text\" : \" Nice to know you ;)\" ,\n "
816- + " \" label\" : \" Spam\" \n "
817- + " },\n "
818- + " {\n "
819- + " \" text\" : \" Please help me?\" ,\n "
820- + " \" label\" : \" Spam\" \n "
821- + " },\n "
822- + " {\n "
823- + " \" text\" : \" Your parcel will be delivered today\" ,\n "
824- + " \" label\" : \" Not spam\" \n "
825- + " },\n "
826- + " {\n "
827- + " \" text\" : \" Review changes to our Terms and Conditions\" ,\n "
828- + " \" label\" : \" Not spam\" \n "
829- + " },\n "
830- + " {\n "
831- + " \" text\" : \" Weekly sync notes\" ,\n "
832- + " \" label\" : \" Not spam\" \n "
833- + " },\n "
834- + " {\n "
835- + " \" text\" : \" Re: Follow up from todays meeting\" ,\n "
836- + " \" label\" : \" Not spam\" \n "
837- + " },\n "
838- + " {\n "
839- + " \" text\" : \" Pre-read for tomorrow\" ,\n "
840- + " \" label\" : \" Not spam\" \n "
841- + " }\n "
842- + " ]\n "
843- + " }\n "
844- + "}" ;
845-
846- response = predictRemoteModel (modelId , predictInput );
847- responseMap = parseResponseToMap (response );
848- List responseList = (List ) responseMap .get ("inference_results" );
849- responseMap = (Map ) responseList .get (0 );
850- responseList = (List ) responseMap .get ("output" );
851- responseMap = (Map ) responseList .get (0 );
852- responseMap = (Map ) responseMap .get ("dataAsMap" );
853- responseList = (List ) responseMap .get ("classifications" );
854- assertFalse (responseList .isEmpty ());
855- }
856-
857743 public static Response createConnector (String input ) throws IOException {
858744 try {
859745 return TestHelper .makeRequest (client (), "POST" , "/_plugins/_ml/connectors/_create" , null , TestHelper .toHttpEntity (input ), null );
0 commit comments