Skip to content

Commit 25c2346

Browse files
xinyualbrianf-aws
authored andcommitted
fix Cohere IT (opensearch-project#4174)
* fix Cohere IT Signed-off-by: xinyual <[email protected]> * apply spotless Signed-off-by: xinyual <[email protected]> * delete useless it Signed-off-by: xinyual <[email protected]> --------- Signed-off-by: xinyual <[email protected]>
1 parent 2e219ee commit 25c2346

File tree

2 files changed

+1
-71
lines changed

2 files changed

+1
-71
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
333333
+ "\"\n"
334334
+ " },\n"
335335
+ " \"parameters\": {\n"
336-
+ " \"model\": \"command\"\n"
336+
+ " \"model\": \"command-a-03-2025\"\n"
337337
+ " },\n"
338338
+ " \"actions\": [\n"
339339
+ " {\n"

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -670,76 +670,6 @@ private void testOpenAITextEmbeddingModel(String charset, Consumer<Map> verifyRe
670670
}
671671
}
672672

673-
public void testCohereGenerateTextModel() throws IOException, InterruptedException {
674-
// Skip test if key is null
675-
if (COHERE_KEY == null) {
676-
return;
677-
}
678-
String entity = "{\n"
679-
+ " \"name\": \"Cohere generate text model Connector\",\n"
680-
+ " \"description\": \"The connector to public Cohere generate text model service\",\n"
681-
+ " \"version\": 1,\n"
682-
+ "\"client_config\": {\n"
683-
+ " \"max_connection\": 20,\n"
684-
+ " \"connection_timeout\": 50000,\n"
685-
+ " \"read_timeout\": 50000\n"
686-
+ " },\n"
687-
+ " \"protocol\": \"http\",\n"
688-
+ " \"parameters\": {\n"
689-
+ " \"endpoint\": \"api.cohere.ai\",\n"
690-
+ " \"auth\": \"API_Key\",\n"
691-
+ " \"content_type\": \"application/json\",\n"
692-
+ " \"max_tokens\": \"20\"\n"
693-
+ " },\n"
694-
+ " \"credential\": {\n"
695-
+ " \"cohere_key\": \""
696-
+ COHERE_KEY
697-
+ "\"\n"
698-
+ " },\n"
699-
+ " \"actions\": [\n"
700-
+ " {\n"
701-
+ " \"action_type\": \"predict\",\n"
702-
+ " \"method\": \"POST\",\n"
703-
+ " \"url\": \"https://${parameters.endpoint}/v1/generate\",\n"
704-
+ " \"headers\": { \n"
705-
+ " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n"
706-
+ " },\n"
707-
+ " \"request_body\": \"{ \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"return_likelihoods\\\": \\\"NONE\\\", \\\"truncate\\\": \\\"END\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\" }\"\n"
708-
+ " }\n"
709-
+ " ]\n"
710-
+ "}";
711-
Response response = createConnector(entity);
712-
Map responseMap = parseResponseToMap(response);
713-
String connectorId = (String) responseMap.get("connector_id");
714-
response = registerRemoteModel("cohere generate text model", connectorId);
715-
responseMap = parseResponseToMap(response);
716-
String taskId = (String) responseMap.get("task_id");
717-
waitForTask(taskId, MLTaskState.COMPLETED);
718-
response = getTask(taskId);
719-
responseMap = parseResponseToMap(response);
720-
String modelId = (String) responseMap.get("model_id");
721-
response = deployRemoteModel(modelId);
722-
responseMap = parseResponseToMap(response);
723-
taskId = (String) responseMap.get("task_id");
724-
waitForTask(taskId, MLTaskState.COMPLETED);
725-
String predictInput = "{\n"
726-
+ " \"parameters\": {\n"
727-
+ " \"prompt\": \"Once upon a time in a magical land called\",\n"
728-
+ " \"max_tokens\": 40\n"
729-
+ " }\n"
730-
+ "}";
731-
response = predictRemoteModel(modelId, predictInput);
732-
responseMap = parseResponseToMap(response);
733-
List responseList = (List) responseMap.get("inference_results");
734-
responseMap = (Map) responseList.get(0);
735-
responseList = (List) responseMap.get("output");
736-
responseMap = (Map) responseList.get(0);
737-
responseMap = (Map) responseMap.get("dataAsMap");
738-
responseList = (List) responseMap.get("generations");
739-
responseMap = (Map) responseList.get(0);
740-
assertFalse(((String) responseMap.get("text")).isEmpty());
741-
}
742-
743673
public static Response createConnector(String input) throws IOException {
744674
try {
745675
return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null);

0 commit comments

Comments
 (0)