diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java index 9782d4881ac61..6191e83a7dca1 100644 --- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java +++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java @@ -39,10 +39,12 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase { // TODO: replace with proper test features private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0"; private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0"; + private static final String COHERE_COMPLETIONS_ADDED_TEST_FEATURE = "gte_v8.15.0"; private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2"; private static MockWebServer cohereEmbeddingsServer; private static MockWebServer cohereRerankServer; + private static MockWebServer cohereCompletionsServer; private enum ApiVersion { V1, @@ -60,12 +62,16 @@ public static void startWebServer() throws IOException { cohereRerankServer = new MockWebServer(); cohereRerankServer.start(); + + cohereCompletionsServer = new MockWebServer(); + cohereCompletionsServer.start(); } @AfterClass public static void shutdown() { cohereEmbeddingsServer.close(); cohereRerankServer.close(); + cohereCompletionsServer.close(); } @SuppressWarnings("unchecked") @@ -326,6 +332,80 @@ private void assertRerank(String inferenceId) throws IOException { assertThat(inferenceMap.entrySet(), not(empty())); } + @SuppressWarnings("unchecked") + public void testCohereCompletions() throws IOException { + var completionsSupported = oldClusterHasFeature(COHERE_COMPLETIONS_ADDED_TEST_FEATURE); + assumeTrue("Cohere completions not supported", completionsSupported); + + ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1; + + final String oldClusterId = "old-cluster-completions"; + + if (isOldCluster()) { + // queue a response as PUT will call the service + cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(oldClusterApiVersion))); + put(oldClusterId, completionsConfig(getUrl(cohereCompletionsServer)), TaskType.COMPLETION); + + var configs = (List>) get(TaskType.COMPLETION, oldClusterId).get("endpoints"); + assertThat(configs, hasSize(1)); + assertEquals("cohere", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "command")); + } else if (isMixedCluster()) { + var configs = (List>) get(TaskType.COMPLETION, oldClusterId).get("endpoints"); + assertThat(configs, hasSize(1)); + assertEquals("cohere", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "command")); + } else if (isUpgradedCluster()) { + // check old cluster model + var configs = (List>) get(TaskType.COMPLETION, oldClusterId).get("endpoints"); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "command")); + + final String newClusterId = "new-cluster-completions"; + { + cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(oldClusterApiVersion))); + var inferenceMap = inference(oldClusterId, TaskType.COMPLETION, "some text"); + assertThat(inferenceMap.entrySet(), not(empty())); + assertVersionInPath(cohereCompletionsServer.requests().getLast(), "chat", oldClusterApiVersion); + } + { + // new cluster uses the V2 API + cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(ApiVersion.V2))); + put(newClusterId, completionsConfig(getUrl(cohereCompletionsServer)), TaskType.COMPLETION); + + cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(ApiVersion.V2))); + var inferenceMap = inference(newClusterId, TaskType.COMPLETION, "some text"); + assertThat(inferenceMap.entrySet(), not(empty())); + assertVersionInPath(cohereCompletionsServer.requests().getLast(), "chat", ApiVersion.V2); + } + + { + // new endpoints use the V2 API which require the model to be set + final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id"; + var jsonBody = Strings.format(""" + { + "service": "cohere", + "service_settings": { + "url": "%s", + "api_key": "XXXX" + } + } + """, getUrl(cohereEmbeddingsServer)); + + var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, TaskType.COMPLETION)); + assertThat( + e.getMessage(), + containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.") + ); + } + + delete(oldClusterId); + delete(newClusterId); + } + } + private String embeddingConfigByte(String url) { return embeddingConfigTemplate(url, "byte"); } @@ -451,4 +531,86 @@ private String rerankResponse() { """; } + private String completionsConfig(String url) { + return Strings.format(""" + { + "service": "cohere", + "service_settings": { + "api_key": "XXXX", + "model_id": "command", + "url": "%s" + } + } + """, url); + } + + private String completionsResponse(ApiVersion version) { + return switch (version) { + case V1 -> v1CompletionsResponse(); + case V2 -> v2CompletionsResponse(); + }; + } + + private String v1CompletionsResponse() { + return """ + { + "response_id": "some id", + "text": "result", + "generation_id": "some id", + "chat_history": [ + { + "role": "USER", + "message": "some input" + }, + { + "role": "CHATBOT", + "message": "v1 response from the llm" + } + ], + "finish_reason": "COMPLETE", + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 4, + "output_tokens": 191 + }, + "tokens": { + "input_tokens": 70, + "output_tokens": 191 + } + } + } + """; + } + + private String v2CompletionsResponse() { + return """ + { + "id": "c14c80c3-18eb-4519-9460-6c92edd8cfb4", + "finish_reason": "COMPLETE", + "message": { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "v2 response from the LLM" + } + ] + }, + "usage": { + "billed_units": { + "input_tokens": 1, + "output_tokens": 2 + }, + "tokens": { + "input_tokens": 3, + "output_tokens": 4 + } + } + } + """; + } + } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java index f512444c6d6a4..2d52a8a9dadbb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java @@ -28,13 +28,16 @@ public class CohereUtils { public static final String DOCUMENTS_FIELD = "documents"; public static final String EMBEDDING_TYPES_FIELD = "embedding_types"; public static final String INPUT_TYPE_FIELD = "input_type"; - public static final String MESSAGE_FIELD = "message"; + public static final String V1_MESSAGE_FIELD = "message"; + public static final String V2_MESSAGES_FIELD = "messages"; public static final String MODEL_FIELD = "model"; public static final String QUERY_FIELD = "query"; + public static final String V2_ROLE_FIELD = "role"; public static final String SEARCH_DOCUMENT = "search_document"; public static final String SEARCH_QUERY = "search_query"; - public static final String TEXTS_FIELD = "texts"; public static final String STREAM_FIELD = "stream"; + public static final String TEXTS_FIELD = "texts"; + public static final String USER_FIELD = "user"; public static Header createRequestSourceHeader() { return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java index 4fa4552dcd94d..0be1ba8d25f29 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java @@ -30,7 +30,7 @@ public CohereV1CompletionRequest(List input, CohereCompletionModel model public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); // we only allow one input for completion, so always get the first one - builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst()); + builder.field(CohereUtils.V1_MESSAGE_FIELD, input.getFirst()); if (getModelId() != null) { builder.field(CohereUtils.MODEL_FIELD, getModelId()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java index 028c4a0d486c0..1a8eae321ac77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java @@ -29,8 +29,13 @@ public CohereV2CompletionRequest(List input, CohereCompletionModel model @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); + builder.startArray(CohereUtils.V2_MESSAGES_FIELD); + builder.startObject(); + builder.field(CohereUtils.V2_ROLE_FIELD, CohereUtils.USER_FIELD); // we only allow one input for completion, so always get the first one - builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst()); + builder.field("content", input.getFirst()); + builder.endObject(); + builder.endArray(); builder.field(CohereUtils.MODEL_FIELD, getModelId()); builder.field(CohereUtils.STREAM_FIELD, isStreaming()); builder.endObject(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index 88d26d5d7eef1..6438a328f9fcf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -209,7 +209,10 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false))); + assertThat( + requestMap, + is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "model", "stream", false)) + ); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java index 78b8b7bdeaf3e..6c5128956fc9b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java @@ -132,7 +132,10 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO ); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false))); + assertThat( + requestMap, + is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "model", "stream", false)) + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java index 2fb51ca8ca457..6003a58bf0340 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java @@ -46,7 +46,10 @@ public void testCreateRequest() throws IOException { assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "required model id", "stream", false))); + assertThat( + requestMap, + is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "required model id", "stream", false)) + ); } public void testDefaultUrl() { @@ -88,6 +91,6 @@ public void testXContents() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, CoreMatchers.is(""" - {"message":"some input","model":"model","stream":false}""")); + {"messages":[{"role":"user","content":"some input"}],"model":"model","stream":false}""")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereCompletionResponseEntityTests.java index 4a60dc5033e22..5d7a76a26e597 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereCompletionResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereCompletionResponseEntityTests.java @@ -64,6 +64,42 @@ public void testFromResponse_CreatesResponseEntityForText() throws IOException { assertThat(chatCompletionResults.getResults().get(0).content(), is("result")); } + public void testFromResponseV2() throws IOException { + String responseJson = """ + { + "id": "abc123", + "finish_reason": "COMPLETE", + "message": { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "Response from the llm" + } + ] + }, + "usage": { + "billed_units": { + "input_tokens": 1, + "output_tokens": 4 + }, + "tokens": { + "input_tokens": 2, + "output_tokens": 5 + } + } + } + """; + + ChatCompletionResults chatCompletionResults = CohereCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("Response from the llm")); + } + public void testFromResponse_FailsWhenTextIsNotPresent() { String responseJson = """ {