@@ -670,76 +670,6 @@ private void testOpenAITextEmbeddingModel(String charset, Consumer<Map> verifyRe
670670 }
671671 }
672672
673- public void testCohereGenerateTextModel () throws IOException , InterruptedException {
674- // Skip test if key is null
675- if (COHERE_KEY == null ) {
676- return ;
677- }
678- String entity = "{\n "
679- + " \" name\" : \" Cohere generate text model Connector\" ,\n "
680- + " \" description\" : \" The connector to public Cohere generate text model service\" ,\n "
681- + " \" version\" : 1,\n "
682- + "\" client_config\" : {\n "
683- + " \" max_connection\" : 20,\n "
684- + " \" connection_timeout\" : 50000,\n "
685- + " \" read_timeout\" : 50000\n "
686- + " },\n "
687- + " \" protocol\" : \" http\" ,\n "
688- + " \" parameters\" : {\n "
689- + " \" endpoint\" : \" api.cohere.ai\" ,\n "
690- + " \" auth\" : \" API_Key\" ,\n "
691- + " \" content_type\" : \" application/json\" ,\n "
692- + " \" max_tokens\" : \" 20\" \n "
693- + " },\n "
694- + " \" credential\" : {\n "
695- + " \" cohere_key\" : \" "
696- + COHERE_KEY
697- + "\" \n "
698- + " },\n "
699- + " \" actions\" : [\n "
700- + " {\n "
701- + " \" action_type\" : \" predict\" ,\n "
702- + " \" method\" : \" POST\" ,\n "
703- + " \" url\" : \" https://${parameters.endpoint}/v1/generate\" ,\n "
704- + " \" headers\" : { \n "
705- + " \" Authorization\" : \" Bearer ${credential.cohere_key}\" \n "
706- + " },\n "
707- + " \" request_body\" : \" { \\ \" max_tokens\\ \" : ${parameters.max_tokens}, \\ \" return_likelihoods\\ \" : \\ \" NONE\\ \" , \\ \" truncate\\ \" : \\ \" END\\ \" , \\ \" prompt\\ \" : \\ \" ${parameters.prompt}\\ \" }\" \n "
708- + " }\n "
709- + " ]\n "
710- + "}" ;
711- Response response = createConnector (entity );
712- Map responseMap = parseResponseToMap (response );
713- String connectorId = (String ) responseMap .get ("connector_id" );
714- response = registerRemoteModel ("cohere generate text model" , connectorId );
715- responseMap = parseResponseToMap (response );
716- String taskId = (String ) responseMap .get ("task_id" );
717- waitForTask (taskId , MLTaskState .COMPLETED );
718- response = getTask (taskId );
719- responseMap = parseResponseToMap (response );
720- String modelId = (String ) responseMap .get ("model_id" );
721- response = deployRemoteModel (modelId );
722- responseMap = parseResponseToMap (response );
723- taskId = (String ) responseMap .get ("task_id" );
724- waitForTask (taskId , MLTaskState .COMPLETED );
725- String predictInput = "{\n "
726- + " \" parameters\" : {\n "
727- + " \" prompt\" : \" Once upon a time in a magical land called\" ,\n "
728- + " \" max_tokens\" : 40\n "
729- + " }\n "
730- + "}" ;
731- response = predictRemoteModel (modelId , predictInput );
732- responseMap = parseResponseToMap (response );
733- List responseList = (List ) responseMap .get ("inference_results" );
734- responseMap = (Map ) responseList .get (0 );
735- responseList = (List ) responseMap .get ("output" );
736- responseMap = (Map ) responseList .get (0 );
737- responseMap = (Map ) responseMap .get ("dataAsMap" );
738- responseList = (List ) responseMap .get ("generations" );
739- responseMap = (Map ) responseList .get (0 );
740- assertFalse (((String ) responseMap .get ("text" )).isEmpty ());
741- }
742-
743673 public static Response createConnector (String input ) throws IOException {
744674 try {
745675 return TestHelper .makeRequest (client (), "POST" , "/_plugins/_ml/connectors/_create" , null , TestHelper .toHttpEntity (input ), null );
0 commit comments