Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
// TODO: replace with proper test features
private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0";
private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0";
private static final String COHERE_COMPLETIONS_ADDED_TEST_FEATURE = "gte_v8.15.0";
private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2";

private static MockWebServer cohereEmbeddingsServer;
private static MockWebServer cohereRerankServer;
private static MockWebServer cohereCompletionsServer;

private enum ApiVersion {
V1,
Expand All @@ -60,12 +62,16 @@ public static void startWebServer() throws IOException {

cohereRerankServer = new MockWebServer();
cohereRerankServer.start();

cohereCompletionsServer = new MockWebServer();
cohereCompletionsServer.start();
}

/**
 * Closes the three mock Cohere servers (embeddings, rerank, completions) used by this test class.
 */
@AfterClass
public static void shutdown() {
    // closed in declaration order: embeddings, rerank, completions
    final MockWebServer[] mockServers = { cohereEmbeddingsServer, cohereRerankServer, cohereCompletionsServer };
    for (MockWebServer mockServer : mockServers) {
        mockServer.close();
    }
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -326,6 +332,80 @@ private void assertRerank(String inferenceId) throws IOException {
assertThat(inferenceMap.entrySet(), not(empty()));
}

/**
 * Exercises a Cohere completion endpoint across the old, mixed and upgraded cluster stages.
 * <ul>
 *   <li>Old cluster: creates the endpoint against the mock completions server and checks its stored config.</li>
 *   <li>Mixed cluster: checks that the endpoint created on the old cluster is still readable.</li>
 *   <li>Upgraded cluster: checks the old endpoint, runs inference through it using the API version it was
 *       created with, creates a new endpoint that must use the V2 API, and verifies that creating an
 *       endpoint without a model id is rejected.</li>
 * </ul>
 * The test is skipped when the old cluster does not have the completions feature.
 *
 * @throws IOException if a request to the cluster fails
 */
public void testCohereCompletions() throws IOException {
    var completionsSupported = oldClusterHasFeature(COHERE_COMPLETIONS_ADDED_TEST_FEATURE);
    assumeTrue("Cohere completions not supported", completionsSupported);

    // endpoints created before the V2 feature existed keep talking to the V1 API
    ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;

    final String oldClusterId = "old-cluster-completions";

    if (isOldCluster()) {
        // queue a response as PUT will call the service
        cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(oldClusterApiVersion)));
        put(oldClusterId, completionsConfig(getUrl(cohereCompletionsServer)), TaskType.COMPLETION);

        assertCompletionsEndpointConfig(oldClusterId);
    } else if (isMixedCluster()) {
        assertCompletionsEndpointConfig(oldClusterId);
    } else if (isUpgradedCluster()) {
        // check old cluster model; the size is asserted before indexing so that a missing
        // endpoint is reported as an assertion failure rather than an IndexOutOfBoundsException
        assertCompletionsEndpointConfig(oldClusterId);

        final String newClusterId = "new-cluster-completions";
        {
            cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(oldClusterApiVersion)));
            var inferenceMap = inference(oldClusterId, TaskType.COMPLETION, "some text");
            assertThat(inferenceMap.entrySet(), not(empty()));
            assertVersionInPath(cohereCompletionsServer.requests().getLast(), "chat", oldClusterApiVersion);
        }
        {
            // new cluster uses the V2 API
            cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(ApiVersion.V2)));
            put(newClusterId, completionsConfig(getUrl(cohereCompletionsServer)), TaskType.COMPLETION);

            cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(ApiVersion.V2)));
            var inferenceMap = inference(newClusterId, TaskType.COMPLETION, "some text");
            assertThat(inferenceMap.entrySet(), not(empty()));
            assertVersionInPath(cohereCompletionsServer.requests().getLast(), "chat", ApiVersion.V2);
        }

        {
            // new endpoints use the V2 API which require the model to be set
            final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
            // point at the completions server, like every other request in this test
            var jsonBody = Strings.format("""
                {
                "service": "cohere",
                "service_settings": {
                "url": "%s",
                "api_key": "XXXX"
                }
                }
                """, getUrl(cohereCompletionsServer));

            var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, TaskType.COMPLETION));
            assertThat(
                e.getMessage(),
                containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
            );
        }

        delete(oldClusterId);
        delete(newClusterId);
    }
}

/**
 * Fetches the completion endpoint {@code inferenceId} and asserts that exactly one endpoint is returned,
 * that its service is {@code cohere} and that its service settings carry the model id {@code command}.
 *
 * @param inferenceId id of the completion endpoint to check
 * @throws IOException if the request to the cluster fails
 */
@SuppressWarnings("unchecked")
private void assertCompletionsEndpointConfig(String inferenceId) throws IOException {
    var configs = (List<Map<String, Object>>) get(TaskType.COMPLETION, inferenceId).get("endpoints");
    assertThat(configs, hasSize(1));
    assertEquals("cohere", configs.get(0).get("service"));
    var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
    assertThat(serviceSettings, hasEntry("model_id", "command"));
}

/**
 * Delegates to {@code embeddingConfigTemplate}, passing {@code url} and the literal {@code "byte"}.
 */
private String embeddingConfigByte(String url) {
    final String typeArgument = "byte";
    return embeddingConfigTemplate(url, typeArgument);
}
Expand Down Expand Up @@ -451,4 +531,86 @@ private String rerankResponse() {
""";
}

/**
 * Renders the JSON request body this test uses to create a Cohere completion endpoint:
 * service {@code cohere}, api key {@code XXXX}, model id {@code command} and the given service URL.
 *
 * @param url value substituted for the {@code url} service setting
 * @return the formatted JSON body
 */
private String completionsConfig(String url) {
    final String template = """
        {
        "service": "cohere",
        "service_settings": {
        "api_key": "XXXX",
        "model_id": "command",
        "url": "%s"
        }
        }
        """;
    return Strings.format(template, url);
}

/**
 * Selects the canned completion response body matching the given API version.
 *
 * @param version the Cohere API version whose response shape is wanted
 * @return the V1 body for {@code V1}, the V2 body for {@code V2}
 */
private String completionsResponse(ApiVersion version) {
    final String responseBody = switch (version) {
        case V1 -> v1CompletionsResponse();
        case V2 -> v2CompletionsResponse();
    };
    return responseBody;
}

/**
 * Canned JSON body that {@code completionsResponse} returns for {@code ApiVersion.V1}.
 * Top-level fields: {@code response_id}, {@code text}, {@code generation_id}, {@code chat_history},
 * {@code finish_reason} and {@code meta} (whose {@code api_version.version} is {@code "1"}).
 */
private String v1CompletionsResponse() {
return """
{
"response_id": "some id",
"text": "result",
"generation_id": "some id",
"chat_history": [
{
"role": "USER",
"message": "some input"
},
{
"role": "CHATBOT",
"message": "v1 response from the llm"
}
],
"finish_reason": "COMPLETE",
"meta": {
"api_version": {
"version": "1"
},
"billed_units": {
"input_tokens": 4,
"output_tokens": 191
},
"tokens": {
"input_tokens": 70,
"output_tokens": 191
}
}
}
""";
}

/**
 * Canned JSON body that {@code completionsResponse} returns for {@code ApiVersion.V2}.
 * Top-level fields: {@code id}, {@code finish_reason}, {@code message} (role {@code assistant} with a
 * {@code content} array holding one {@code text} part) and {@code usage}.
 */
private String v2CompletionsResponse() {
return """
{
"id": "c14c80c3-18eb-4519-9460-6c92edd8cfb4",
"finish_reason": "COMPLETE",
"message": {
"role": "assistant",
"content": [
{
"type": "text",
"text": "v2 response from the LLM"
}
]
},
"usage": {
"billed_units": {
"input_tokens": 1,
"output_tokens": 2
},
"tokens": {
"input_tokens": 3,
"output_tokens": 4
}
}
}
""";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@ public class CohereUtils {
public static final String DOCUMENTS_FIELD = "documents";
public static final String EMBEDDING_TYPES_FIELD = "embedding_types";
public static final String INPUT_TYPE_FIELD = "input_type";
public static final String MESSAGE_FIELD = "message";
public static final String V1_MESSAGE_FIELD = "message";
public static final String V2_MESSAGES_FIELD = "messages";
public static final String MODEL_FIELD = "model";
public static final String QUERY_FIELD = "query";
public static final String V2_ROLE_FIELD = "role";
public static final String SEARCH_DOCUMENT = "search_document";
public static final String SEARCH_QUERY = "search_query";
public static final String TEXTS_FIELD = "texts";
public static final String STREAM_FIELD = "stream";
public static final String TEXTS_FIELD = "texts";
public static final String USER_FIELD = "user";

public static Header createRequestSourceHeader() {
return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public CohereV1CompletionRequest(List<String> input, CohereCompletionModel model
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
// we only allow one input for completion, so always get the first one
builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst());
builder.field(CohereUtils.V1_MESSAGE_FIELD, input.getFirst());
if (getModelId() != null) {
builder.field(CohereUtils.MODEL_FIELD, getModelId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ public CohereV2CompletionRequest(List<String> input, CohereCompletionModel model
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startArray(CohereUtils.V2_MESSAGES_FIELD);
builder.startObject();
builder.field(CohereUtils.V2_ROLE_FIELD, CohereUtils.USER_FIELD);
// we only allow one input for completion, so always get the first one
builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst());
builder.field("content", input.getFirst());
builder.endObject();
builder.endArray();
builder.field(CohereUtils.MODEL_FIELD, getModelId());
builder.field(CohereUtils.STREAM_FIELD, isStreaming());
builder.endObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,10 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret"));

var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false)));
assertThat(
requestMap,
is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "model", "stream", false))
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO
);

var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false)));
assertThat(
requestMap,
is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "model", "stream", false))
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ public void testCreateRequest() throws IOException {
assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE));

var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap, is(Map.of("message", "abc", "model", "required model id", "stream", false)));
assertThat(
requestMap,
is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "required model id", "stream", false))
);
}

public void testDefaultUrl() {
Expand Down Expand Up @@ -88,6 +91,6 @@ public void testXContents() throws IOException {
String xContentResult = Strings.toString(builder);

assertThat(xContentResult, CoreMatchers.is("""
{"message":"some input","model":"model","stream":false}"""));
{"messages":[{"role":"user","content":"some input"}],"model":"model","stream":false}"""));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,42 @@ public void testFromResponse_CreatesResponseEntityForText() throws IOException {
assertThat(chatCompletionResults.getResults().get(0).content(), is("result"));
}

/**
 * Feeds a V2-shaped chat response (an {@code assistant} message whose {@code content} array holds one
 * {@code text} part) to {@code CohereCompletionResponseEntity.fromResponse} and expects a single result
 * whose content equals that text.
 */
public void testFromResponseV2() throws IOException {
    final String v2Body = """
        {
        "id": "abc123",
        "finish_reason": "COMPLETE",
        "message": {
        "role": "assistant",
        "content": [
        {
        "type": "text",
        "text": "Response from the llm"
        }
        ]
        },
        "usage": {
        "billed_units": {
        "input_tokens": 1,
        "output_tokens": 4
        },
        "tokens": {
        "input_tokens": 2,
        "output_tokens": 5
        }
        }
        }
        """;

    final Request request = mock(Request.class);
    final HttpResult httpResult = new HttpResult(mock(HttpResponse.class), v2Body.getBytes(StandardCharsets.UTF_8));

    final ChatCompletionResults parsed = CohereCompletionResponseEntity.fromResponse(request, httpResult);

    final var results = parsed.getResults();
    assertThat(results.size(), is(1));
    assertThat(results.get(0).content(), is("Response from the llm"));
}

public void testFromResponse_FailsWhenTextIsNotPresent() {
String responseJson = """
{
Expand Down