Skip to content

Commit 97ce2b4

Browse files
committed
v2 messages format
1 parent 2a9f894 commit 97ce2b4

File tree

7 files changed

+61
-8
lines changed

7 files changed

+61
-8
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,16 @@ public class CohereUtils {
2828
public static final String DOCUMENTS_FIELD = "documents";
2929
public static final String EMBEDDING_TYPES_FIELD = "embedding_types";
3030
public static final String INPUT_TYPE_FIELD = "input_type";
31-
public static final String MESSAGE_FIELD = "message";
31+
public static final String V1_MESSAGE_FIELD = "message";
32+
public static final String V2_MESSAGES_FIELD = "messages";
3233
public static final String MODEL_FIELD = "model";
3334
public static final String QUERY_FIELD = "query";
35+
public static final String V2_ROLE_FIELD = "role";
3436
public static final String SEARCH_DOCUMENT = "search_document";
3537
public static final String SEARCH_QUERY = "search_query";
36-
public static final String TEXTS_FIELD = "texts";
3738
public static final String STREAM_FIELD = "stream";
39+
public static final String TEXTS_FIELD = "texts";
40+
public static final String USER_FIELD = "user";
3841

3942
public static Header createRequestSourceHeader() {
4043
return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public CohereV1CompletionRequest(List<String> input, CohereCompletionModel model
3030
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
3131
builder.startObject();
3232
// we only allow one input for completion, so always get the first one
33-
builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst());
33+
builder.field(CohereUtils.V1_MESSAGE_FIELD, input.getFirst());
3434
if (getModelId() != null) {
3535
builder.field(CohereUtils.MODEL_FIELD, getModelId());
3636
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,13 @@ public CohereV2CompletionRequest(List<String> input, CohereCompletionModel model
2929
@Override
3030
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
3131
builder.startObject();
32+
builder.startArray(CohereUtils.V2_MESSAGES_FIELD);
33+
builder.startObject();
34+
builder.field(CohereUtils.V2_ROLE_FIELD, CohereUtils.USER_FIELD);
3235
// we only allow one input for completion, so always get the first one
33-
builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst());
36+
builder.field("content", input.getFirst());
37+
builder.endObject();
38+
builder.endArray();
3439
builder.field(CohereUtils.MODEL_FIELD, getModelId());
3540
builder.field(CohereUtils.STREAM_FIELD, isStreaming());
3641
builder.endObject();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,10 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep
209209
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret"));
210210

211211
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
212-
assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false)));
212+
assertThat(
213+
requestMap,
214+
is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "model", "stream", false))
215+
);
213216
}
214217
}
215218
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,10 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO
132132
);
133133

134134
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
135-
assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false)));
135+
assertThat(
136+
requestMap,
137+
is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "model", "stream", false))
138+
);
136139
}
137140
}
138141

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ public void testCreateRequest() throws IOException {
4646
assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE));
4747

4848
var requestMap = entityAsMap(httpPost.getEntity().getContent());
49-
assertThat(requestMap, is(Map.of("message", "abc", "model", "required model id", "stream", false)));
49+
assertThat(
50+
requestMap,
51+
is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "required model id", "stream", false))
52+
);
5053
}
5154

5255
public void testDefaultUrl() {
@@ -88,6 +91,6 @@ public void testXContents() throws IOException {
8891
String xContentResult = Strings.toString(builder);
8992

9093
assertThat(xContentResult, CoreMatchers.is("""
91-
{"message":"some input","model":"model","stream":false}"""));
94+
{"messages":[{"role":"user","content":"some input"}],"model":"model","stream":false}"""));
9295
}
9396
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereCompletionResponseEntityTests.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,42 @@ public void testFromResponse_CreatesResponseEntityForText() throws IOException {
6464
assertThat(chatCompletionResults.getResults().get(0).content(), is("result"));
6565
}
6666

67+
public void testFromResponseV2() throws IOException {
68+
String responseJson = """
69+
{
70+
"id": "abc123",
71+
"finish_reason": "COMPLETE",
72+
"message": {
73+
"role": "assistant",
74+
"content": [
75+
{
76+
"type": "text",
77+
"text": "Response from the llm"
78+
}
79+
]
80+
},
81+
"usage": {
82+
"billed_units": {
83+
"input_tokens": 1,
84+
"output_tokens": 4
85+
},
86+
"tokens": {
87+
"input_tokens": 2,
88+
"output_tokens": 5
89+
}
90+
}
91+
}
92+
""";
93+
94+
ChatCompletionResults chatCompletionResults = CohereCompletionResponseEntity.fromResponse(
95+
mock(Request.class),
96+
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
97+
);
98+
99+
assertThat(chatCompletionResults.getResults().size(), is(1));
100+
assertThat(chatCompletionResults.getResults().get(0).content(), is("Response from the llm"));
101+
}
102+
67103
public void testFromResponse_FailsWhenTextIsNotPresent() {
68104
String responseJson = """
69105
{

0 commit comments

Comments
 (0)