Skip to content

Commit 3a1551e

Browse files
authored
[ML] Move to the Cohere V2 API for new inference endpoints (#129884)
1 parent 73b0a60 commit 3a1551e

File tree

58 files changed

+2206
-1359
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+2206
-1359
lines changed

docs/changelog/129884.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129884
2+
summary: Move to the Cohere V2 API for new inference endpoints
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

muted-tests.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,6 @@ tests:
173173
- class: org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT
174174
method: test {yaml=analysis-common/40_token_filters/stemmer_override file access}
175175
issue: https://github.com/elastic/elasticsearch/issues/121625
176-
- class: org.elasticsearch.xpack.application.CohereServiceUpgradeIT
177-
issue: https://github.com/elastic/elasticsearch/issues/121537
178176
- class: org.elasticsearch.test.rest.ClientYamlTestSuiteIT
179177
method: test {yaml=snapshot.delete/10_basic/Delete a snapshot asynchronously}
180178
issue: https://github.com/elastic/elasticsearch/issues/122102

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ static TransportVersion def(int id) {
207207
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_57);
208208
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19 = def(8_841_0_58);
209209
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_59);
210+
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION_8_19 = def(8_841_0_60);
210211
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
211212
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
212213
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -320,6 +321,7 @@ static TransportVersion def(int id) {
320321
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS = def(9_107_0_00);
321322
public static final TransportVersion CLUSTER_STATE_PROJECTS_SETTINGS = def(9_108_0_00);
322323
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_109_00_0);
324+
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_110_0_00);
323325

324326
/*
325327
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99

1010
import com.carrotsearch.randomizedtesting.annotations.Name;
1111

12+
import org.elasticsearch.client.ResponseException;
1213
import org.elasticsearch.common.Strings;
1314
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.test.http.MockRequest;
1416
import org.elasticsearch.test.http.MockResponse;
1517
import org.elasticsearch.test.http.MockWebServer;
1618
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
@@ -24,6 +26,7 @@
2426

2527
import static org.hamcrest.Matchers.anEmptyMap;
2628
import static org.hamcrest.Matchers.anyOf;
29+
import static org.hamcrest.Matchers.containsString;
2730
import static org.hamcrest.Matchers.empty;
2831
import static org.hamcrest.Matchers.hasEntry;
2932
import static org.hamcrest.Matchers.hasSize;
@@ -36,10 +39,16 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
3639
// TODO: replace with proper test features
3740
private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0";
3841
private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0";
42+
private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2";
3943

4044
private static MockWebServer cohereEmbeddingsServer;
4145
private static MockWebServer cohereRerankServer;
4246

47+
private enum ApiVersion {
48+
V1,
49+
V2
50+
}
51+
4352
public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) {
4453
super(upgradedNodes);
4554
}
@@ -62,15 +71,18 @@ public static void shutdown() {
6271
@SuppressWarnings("unchecked")
6372
public void testCohereEmbeddings() throws IOException {
6473
var embeddingsSupported = oldClusterHasFeature(COHERE_EMBEDDINGS_ADDED_TEST_FEATURE);
65-
String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
6674
assumeTrue("Cohere embedding service supported", embeddingsSupported);
6775

76+
String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
77+
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;
78+
6879
final String oldClusterIdInt8 = "old-cluster-embeddings-int8";
6980
final String oldClusterIdFloat = "old-cluster-embeddings-float";
7081

7182
var testTaskType = TaskType.TEXT_EMBEDDING;
7283

7384
if (isOldCluster()) {
85+
7486
// queue a response as PUT will call the service
7587
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
7688
put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
@@ -128,13 +140,17 @@ public void testCohereEmbeddings() throws IOException {
128140

129141
// Inference on old cluster models
130142
assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE);
143+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", oldClusterApiVersion);
131144
assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT);
145+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", oldClusterApiVersion);
132146

133147
{
134148
final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte";
135149

150+
// new endpoints use the V2 API
136151
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
137152
put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), testTaskType);
153+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
138154

139155
configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdByte).get("endpoints");
140156
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
@@ -146,34 +162,70 @@ public void testCohereEmbeddings() throws IOException {
146162
{
147163
final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8";
148164

165+
// new endpoints use the V2 API
149166
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
150167
put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
168+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
151169

152170
configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdInt8).get("endpoints");
153171
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
154172
assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte
155173

156174
assertEmbeddingInference(upgradedClusterIdInt8, CohereEmbeddingType.INT8);
175+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
157176
delete(upgradedClusterIdInt8);
158177
}
159178
{
160179
final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float";
161180
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
162181
put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType);
182+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
163183

164184
configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdFloat).get("endpoints");
165185
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
166186
assertThat(serviceSettings, hasEntry("embedding_type", "float"));
167187

168188
assertEmbeddingInference(upgradedClusterIdFloat, CohereEmbeddingType.FLOAT);
189+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
169190
delete(upgradedClusterIdFloat);
170191
}
192+
{
193+
// new endpoints use the V2 API which require the model to be set
194+
final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
195+
var jsonBody = Strings.format("""
196+
{
197+
"service": "cohere",
198+
"service_settings": {
199+
"url": "%s",
200+
"api_key": "XXXX",
201+
"embedding_type": "int8"
202+
}
203+
}
204+
""", getUrl(cohereEmbeddingsServer));
205+
206+
var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
207+
assertThat(
208+
e.getMessage(),
209+
containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
210+
);
211+
}
171212

172213
delete(oldClusterIdFloat);
173214
delete(oldClusterIdInt8);
174215
}
175216
}
176217

218+
private void assertVersionInPath(MockRequest request, String endpoint, ApiVersion apiVersion) {
219+
switch (apiVersion) {
220+
case V2:
221+
assertEquals("/v2/" + endpoint, request.getUri().getPath());
222+
break;
223+
case V1:
224+
assertEquals("/v1/" + endpoint, request.getUri().getPath());
225+
break;
226+
}
227+
}
228+
177229
void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException {
178230
switch (type) {
179231
case INT8:
@@ -191,9 +243,11 @@ void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) thro
191243
@SuppressWarnings("unchecked")
192244
public void testRerank() throws IOException {
193245
var rerankSupported = oldClusterHasFeature(COHERE_RERANK_ADDED_TEST_FEATURE);
194-
String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
195246
assumeTrue("Cohere rerank service supported", rerankSupported);
196247

248+
String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
249+
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;
250+
197251
final String oldClusterId = "old-cluster-rerank";
198252
final String upgradedClusterId = "upgraded-cluster-rerank";
199253

@@ -216,7 +270,6 @@ public void testRerank() throws IOException {
216270
assertThat(taskSettings, hasEntry("top_n", 3));
217271

218272
assertRerank(oldClusterId);
219-
220273
} else if (isUpgradedCluster()) {
221274
// check old cluster model
222275
var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get("endpoints");
@@ -227,6 +280,7 @@ public void testRerank() throws IOException {
227280
assertThat(taskSettings, hasEntry("top_n", 3));
228281

229282
assertRerank(oldClusterId);
283+
assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", oldClusterApiVersion);
230284

231285
// New endpoint
232286
cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
@@ -235,6 +289,27 @@ public void testRerank() throws IOException {
235289
assertThat(configs, hasSize(1));
236290

237291
assertRerank(upgradedClusterId);
292+
assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", ApiVersion.V2);
293+
294+
{
295+
// new endpoints use the V2 API which require the model_id to be set
296+
final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
297+
var jsonBody = Strings.format("""
298+
{
299+
"service": "cohere",
300+
"service_settings": {
301+
"url": "%s",
302+
"api_key": "XXXX"
303+
}
304+
}
305+
""", getUrl(cohereEmbeddingsServer));
306+
307+
var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
308+
assertThat(
309+
e.getMessage(),
310+
containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
311+
);
312+
}
238313

239314
delete(oldClusterId);
240315
delete(upgradedClusterId);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public class InferenceFeatures implements FeatureSpecification {
3737
"test_rule_retriever.with_indices_that_dont_return_rank_docs"
3838
);
3939
private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter");
40+
private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2");
4041

4142
@Override
4243
public Set<NodeFeature> getTestFeatures() {
@@ -64,7 +65,8 @@ public Set<NodeFeature> getTestFeatures() {
6465
SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG,
6566
SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER,
6667
SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
67-
SEMANTIC_TEXT_INDEX_OPTIONS
68+
SEMANTIC_TEXT_INDEX_OPTIONS,
69+
COHERE_V2_API
6870
);
6971
}
7072
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {
2727
private final Boolean returnDocuments;
2828
private final Integer topN;
2929

30-
public QueryAndDocsInputs(String query, List<String> chunks) {
31-
this(query, chunks, null, null, false);
32-
}
33-
3430
public QueryAndDocsInputs(
3531
String query,
3632
List<String> chunks,
@@ -45,6 +41,10 @@ public QueryAndDocsInputs(
4541
this.topN = topN;
4642
}
4743

44+
public QueryAndDocsInputs(String query, List<String> chunks) {
45+
this(query, chunks, null, null, false);
46+
}
47+
4848
public String getQuery() {
4949
return query;
5050
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereAccount.java

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,35 @@
77

88
package org.elasticsearch.xpack.inference.services.cohere;
99

10-
import org.elasticsearch.common.CheckedSupplier;
10+
import org.apache.http.client.utils.URIBuilder;
11+
import org.elasticsearch.ElasticsearchStatusException;
12+
import org.elasticsearch.common.Strings;
1113
import org.elasticsearch.common.settings.SecureString;
14+
import org.elasticsearch.rest.RestStatus;
15+
import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils;
1216

1317
import java.net.URI;
1418
import java.net.URISyntaxException;
1519
import java.util.Objects;
1620

17-
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
18-
19-
public record CohereAccount(URI uri, SecureString apiKey) {
20-
21-
public static CohereAccount of(CohereModel model, CheckedSupplier<URI, URISyntaxException> uriBuilder) {
22-
var uri = buildUri(model.uri(), "Cohere", uriBuilder);
23-
24-
return new CohereAccount(uri, model.apiKey());
21+
public record CohereAccount(URI baseUri, SecureString apiKey) {
22+
23+
public static CohereAccount of(CohereModel model) {
24+
try {
25+
var uri = model.baseUri() != null ? model.baseUri() : new URIBuilder().setScheme("https").setHost(CohereUtils.HOST).build();
26+
return new CohereAccount(uri, model.apiKey());
27+
} catch (URISyntaxException e) {
28+
// using bad request here so that potentially sensitive URL information does not get logged
29+
throw new ElasticsearchStatusException(
30+
Strings.format("Failed to construct %s URL", CohereService.NAME),
31+
RestStatus.BAD_REQUEST,
32+
e
33+
);
34+
}
2535
}
2636

2737
public CohereAccount {
28-
Objects.requireNonNull(uri);
38+
Objects.requireNonNull(baseUri);
2939
Objects.requireNonNull(apiKey);
3040
}
3141
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java

Lines changed: 0 additions & 62 deletions
This file was deleted.

0 commit comments

Comments
 (0)