Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,6 @@ tests:
- class: org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT
method: test {yaml=analysis-common/40_token_filters/stemmer_override file access}
issue: https://github.com/elastic/elasticsearch/issues/121625
- class: org.elasticsearch.xpack.application.CohereServiceUpgradeIT
issue: https://github.com/elastic/elasticsearch/issues/121537
- class: org.elasticsearch.test.rest.ClientYamlTestSuiteIT
method: test {yaml=snapshot.delete/10_basic/Delete a snapshot asynchronously}
issue: https://github.com/elastic/elasticsearch/issues/122102
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ static TransportVersion def(int id) {
public static final TransportVersion STREAMS_LOGS_SUPPORT_8_19 = def(8_841_0_54);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19 = def(8_841_0_55);
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56);

public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION_8_19 = def(8_841_0_57);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -311,6 +311,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00);
public static final TransportVersion STREAMS_LOGS_SUPPORT = def(9_104_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE = def(9_105_0_00);
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_106_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

import com.carrotsearch.randomizedtesting.annotations.Name;

import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockRequest;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
Expand All @@ -24,6 +26,7 @@

import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
Expand All @@ -36,10 +39,16 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
// TODO: replace with proper test features
private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0";
private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0";
private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2";

private static MockWebServer cohereEmbeddingsServer;
private static MockWebServer cohereRerankServer;

/**
 * Cohere API versions this test distinguishes between. {@code assertVersionInPath}
 * maps {@code V1} to the {@code /v1/} path prefix and {@code V2} to {@code /v2/}.
 */
private enum ApiVersion {
V1,
V2
}

/**
 * Creates the test instance and forwards {@code upgradedNodes} to the superclass constructor.
 *
 * @param upgradedNodes value bound to the {@code "upgradedNodes"} test parameter
 *                      (presumably the number of nodes already upgraded; confirm against InferenceUpgradeTestCase)
 */
public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) {
super(upgradedNodes);
}
Expand All @@ -62,15 +71,18 @@ public static void shutdown() {
@SuppressWarnings("unchecked")
public void testCohereEmbeddings() throws IOException {
var embeddingsSupported = oldClusterHasFeature(COHERE_EMBEDDINGS_ADDED_TEST_FEATURE);
String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
assumeTrue("Cohere embedding service supported", embeddingsSupported);

String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;

final String oldClusterIdInt8 = "old-cluster-embeddings-int8";
final String oldClusterIdFloat = "old-cluster-embeddings-float";

var testTaskType = TaskType.TEXT_EMBEDDING;

if (isOldCluster()) {

// queue a response as PUT will call the service
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
Expand Down Expand Up @@ -128,13 +140,17 @@ public void testCohereEmbeddings() throws IOException {

// Inference on old cluster models
assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", oldClusterApiVersion);
assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", oldClusterApiVersion);

{
final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte";

// new endpoints use the V2 API
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), testTaskType);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);

configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdByte).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
Expand All @@ -146,34 +162,70 @@ public void testCohereEmbeddings() throws IOException {
{
final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8";

// new endpoints use the V2 API
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);

configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdInt8).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte

assertEmbeddingInference(upgradedClusterIdInt8, CohereEmbeddingType.INT8);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
delete(upgradedClusterIdInt8);
}
{
final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float";
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);

configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdFloat).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
assertThat(serviceSettings, hasEntry("embedding_type", "float"));

assertEmbeddingInference(upgradedClusterIdFloat, CohereEmbeddingType.FLOAT);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
delete(upgradedClusterIdFloat);
}
{
// new endpoints use the V2 API, which requires the model to be set
final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
var jsonBody = Strings.format("""
{
"service": "cohere",
"service_settings": {
"url": "%s",
"api_key": "XXXX",
"embedding_type": "int8"
}
}
""", getUrl(cohereEmbeddingsServer));

var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
assertThat(
e.getMessage(),
containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
);
}

delete(oldClusterIdFloat);
delete(oldClusterIdInt8);
}
}

/**
 * Asserts that the path of the given request is {@code endpoint} under the version
 * prefix that corresponds to {@code apiVersion}: {@code /v1/<endpoint>} for V1 and
 * {@code /v2/<endpoint>} for V2.
 *
 * @param request    the recorded request whose URI path is checked
 * @param endpoint   the path segment expected after the version prefix, e.g. {@code "embed"}
 * @param apiVersion the API version the request is expected to target
 */
private void assertVersionInPath(MockRequest request, String endpoint, ApiVersion apiVersion) {
    // Only the prefix differs between versions, so select it once and assert once.
    final String versionPrefix = switch (apiVersion) {
        case V1 -> "/v1/";
        case V2 -> "/v2/";
    };
    assertEquals(versionPrefix + endpoint, request.getUri().getPath());
}

void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException {
switch (type) {
case INT8:
Expand All @@ -191,9 +243,11 @@ void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) thro
@SuppressWarnings("unchecked")
public void testRerank() throws IOException {
var rerankSupported = oldClusterHasFeature(COHERE_RERANK_ADDED_TEST_FEATURE);
String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
assumeTrue("Cohere rerank service supported", rerankSupported);

String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;

final String oldClusterId = "old-cluster-rerank";
final String upgradedClusterId = "upgraded-cluster-rerank";

Expand All @@ -216,7 +270,6 @@ public void testRerank() throws IOException {
assertThat(taskSettings, hasEntry("top_n", 3));

assertRerank(oldClusterId);

} else if (isUpgradedCluster()) {
// check old cluster model
var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get("endpoints");
Expand All @@ -227,6 +280,7 @@ public void testRerank() throws IOException {
assertThat(taskSettings, hasEntry("top_n", 3));

assertRerank(oldClusterId);
assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", oldClusterApiVersion);

// New endpoint
cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
Expand All @@ -235,6 +289,27 @@ public void testRerank() throws IOException {
assertThat(configs, hasSize(1));

assertRerank(upgradedClusterId);
assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", ApiVersion.V2);

{
// new endpoints use the V2 API, which requires the model_id to be set
final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
var jsonBody = Strings.format("""
{
"service": "cohere",
"service_settings": {
"url": "%s",
"api_key": "XXXX"
}
}
""", getUrl(cohereEmbeddingsServer));

var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
assertThat(
e.getMessage(),
containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
);
}

delete(oldClusterId);
delete(upgradedClusterId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class InferenceFeatures implements FeatureSpecification {
"test_rule_retriever.with_indices_that_dont_return_rank_docs"
);
private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter");
private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2");

@Override
public Set<NodeFeature> getTestFeatures() {
Expand Down Expand Up @@ -64,7 +65,8 @@ public Set<NodeFeature> getTestFeatures() {
SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG,
SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER,
SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
SEMANTIC_TEXT_INDEX_OPTIONS
SEMANTIC_TEXT_INDEX_OPTIONS,
COHERE_V2_API
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {
private final Boolean returnDocuments;
private final Integer topN;

public QueryAndDocsInputs(String query, List<String> chunks) {
this(query, chunks, null, null, false);
}

public QueryAndDocsInputs(
String query,
List<String> chunks,
Expand All @@ -45,6 +41,10 @@ public QueryAndDocsInputs(
this.topN = topN;
}

/**
 * Convenience constructor that takes only the query and the chunks. It delegates to the
 * five-argument constructor, passing {@code null}, {@code null} and {@code false} for the
 * remaining arguments.
 *
 * @param query  the query text
 * @param chunks the chunks passed alongside the query
 */
public QueryAndDocsInputs(String query, List<String> chunks) {
this(query, chunks, null, null, false);
}

/**
 * Returns the value of the {@code query} field.
 *
 * @return the query
 */
public String getQuery() {
return query;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,35 @@

package org.elasticsearch.xpack.inference.services.cohere;

import org.elasticsearch.common.CheckedSupplier;
import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;

public record CohereAccount(URI uri, SecureString apiKey) {

public static CohereAccount of(CohereModel model, CheckedSupplier<URI, URISyntaxException> uriBuilder) {
var uri = buildUri(model.uri(), "Cohere", uriBuilder);

return new CohereAccount(uri, model.apiKey());
/**
 * Connection details for calling Cohere: the base URI requests are sent to and the API key.
 * Both components are required to be non-null.
 */
public record CohereAccount(URI baseUri, SecureString apiKey) {

    public CohereAccount {
        Objects.requireNonNull(baseUri);
        Objects.requireNonNull(apiKey);
    }

    /**
     * Builds an account from a model. The model's base URI is used when it is set;
     * otherwise an {@code https} URI for {@link CohereUtils#HOST} is constructed.
     *
     * @param model the model supplying the optional base URI and the API key
     * @return the account for the model
     * @throws ElasticsearchStatusException with {@link RestStatus#BAD_REQUEST} if the default URI cannot be built
     */
    public static CohereAccount of(CohereModel model) {
        final URI configuredUri = model.baseUri();
        if (configuredUri != null) {
            return new CohereAccount(configuredUri, model.apiKey());
        }

        try {
            final URI defaultUri = new URIBuilder().setScheme("https").setHost(CohereUtils.HOST).build();
            return new CohereAccount(defaultUri, model.apiKey());
        } catch (URISyntaxException e) {
            // using bad request here so that potentially sensitive URL information does not get logged
            var message = Strings.format("Failed to construct %s URL", CohereService.NAME);
            throw new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST, e);
        }
    }
}

This file was deleted.

Loading