Skip to content

Commit 92a373b

Browse files
committed
Fix the tests
1 parent d37dd1f commit 92a373b

26 files changed

+658
-253
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ static TransportVersion def(int id) {
202202
public static final TransportVersion STREAMS_LOGS_SUPPORT_8_19 = def(8_841_0_54);
203203
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19 = def(8_841_0_55);
204204
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56);
205-
205+
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION_8_19 = def(8_841_0_57);
206206
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
207207
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
208208
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -311,6 +311,7 @@ static TransportVersion def(int id) {
311311
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00);
312312
public static final TransportVersion STREAMS_LOGS_SUPPORT = def(9_104_0_00);
313313
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE = def(9_105_0_00);
314+
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_106_0_00);
314315

315316
/*
316317
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import com.carrotsearch.randomizedtesting.annotations.Name;
1111

12+
import org.elasticsearch.client.ResponseException;
1213
import org.elasticsearch.common.Strings;
1314
import org.elasticsearch.inference.TaskType;
1415
import org.elasticsearch.test.http.MockRequest;
@@ -25,6 +26,7 @@
2526

2627
import static org.hamcrest.Matchers.anEmptyMap;
2728
import static org.hamcrest.Matchers.anyOf;
29+
import static org.hamcrest.Matchers.containsString;
2830
import static org.hamcrest.Matchers.empty;
2931
import static org.hamcrest.Matchers.hasEntry;
3032
import static org.hamcrest.Matchers.hasSize;
@@ -37,13 +39,14 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
3739
// TODO: replace with proper test features
3840
private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0";
3941
private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0";
40-
private static final String V2_API = "gte_v8.19.0";
42+
private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2";
4143

4244
private static MockWebServer cohereEmbeddingsServer;
4345
private static MockWebServer cohereRerankServer;
4446

4547
private enum ApiVersion {
46-
V1, V2
48+
V1,
49+
V2
4750
}
4851

4952
public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) {
@@ -71,7 +74,7 @@ public void testCohereEmbeddings() throws IOException {
7174
assumeTrue("Cohere embedding service supported", embeddingsSupported);
7275

7376
String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
74-
ApiVersion oldClusterApiVersion = oldClusterHasFeature(V2_API) ? ApiVersion.V2 : ApiVersion.V1;
77+
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;
7578

7679
final String oldClusterIdInt8 = "old-cluster-embeddings-int8";
7780
final String oldClusterIdFloat = "old-cluster-embeddings-float";
@@ -186,6 +189,26 @@ public void testCohereEmbeddings() throws IOException {
186189
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
187190
delete(upgradedClusterIdFloat);
188191
}
192+
{
193+
// new endpoints use the V2 API, which requires the model_id to be set
194+
final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
195+
var jsonBody = Strings.format("""
196+
{
197+
"service": "cohere",
198+
"service_settings": {
199+
"url": "%s",
200+
"api_key": "XXXX",
201+
"embedding_type": "int8"
202+
}
203+
}
204+
""", getUrl(cohereEmbeddingsServer));
205+
206+
var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
207+
assertThat(
208+
e.getMessage(),
209+
containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
210+
);
211+
}
189212

190213
delete(oldClusterIdFloat);
191214
delete(oldClusterIdInt8);
@@ -223,7 +246,7 @@ public void testRerank() throws IOException {
223246
assumeTrue("Cohere rerank service supported", rerankSupported);
224247

225248
String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
226-
ApiVersion oldClusterApiVersion = oldClusterHasFeature(V2_API) ? ApiVersion.V2 : ApiVersion.V1;
249+
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;
227250

228251
final String oldClusterId = "old-cluster-rerank";
229252
final String upgradedClusterId = "upgraded-cluster-rerank";
@@ -268,6 +291,26 @@ public void testRerank() throws IOException {
268291
assertRerank(upgradedClusterId);
269292
assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", ApiVersion.V2);
270293

294+
{
    // New endpoints use the V2 API, which requires the model_id to be set.
    // The URL is built from the rerank mock server so this scope is consistent with the
    // rest of testRerank, which inspects cohereRerankServer for its requests.
    final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
    var jsonBody = Strings.format("""
        {
        "service": "cohere",
        "service_settings": {
        "url": "%s",
        "api_key": "XXXX"
        }
        }
        """, getUrl(cohereRerankServer));

    // No model_id is supplied, so creating the endpoint is expected to fail validation.
    var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
    assertThat(
        e.getMessage(),
        containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
    );
}
313+
271314
delete(oldClusterId);
272315
delete(upgradedClusterId);
273316
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public class InferenceFeatures implements FeatureSpecification {
3737
"test_rule_retriever.with_indices_that_dont_return_rank_docs"
3838
);
3939
private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter");
40+
private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2");
4041

4142
@Override
4243
public Set<NodeFeature> getTestFeatures() {
@@ -64,7 +65,8 @@ public Set<NodeFeature> getTestFeatures() {
6465
SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG,
6566
SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER,
6667
SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
67-
SEMANTIC_TEXT_INDEX_OPTIONS
68+
SEMANTIC_TEXT_INDEX_OPTIONS,
69+
COHERE_V2_API
6870
);
6971
}
7072
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ public QueryAndDocsInputs(
4141
this.topN = topN;
4242
}
4343

44+
/**
 * Convenience constructor that takes only the query and the chunks.
 * Delegates to the five-argument constructor, passing {@code null}, {@code null} and
 * {@code false} for the remaining arguments.
 *
 * @param query  the query string
 * @param chunks the chunks associated with the query
 */
public QueryAndDocsInputs(String query, List<String> chunks) {
45+
this(query, chunks, null, null, false);
46+
}
47+
4448
public String getQuery() {
4549
return query;
4650
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Objects;
2626

2727
public abstract class CohereModel extends RateLimitGroupingModel {
28+
2829
private final SecureString apiKey;
2930
private final CohereRateLimitServiceSettings rateLimitServiceSettings;
3031

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,7 @@
1414
/**
 * View of a Cohere service configuration exposing its rate limit settings,
 * its API version and its URI.
 */
public interface CohereRateLimitServiceSettings {
1515
/** Returns the rate limit settings of this configuration. */
RateLimitSettings rateLimitSettings();
1616

17+
/** Returns the Cohere API version of this configuration. */
CohereServiceSettings.CohereApiVersion apiVersion();
18+
1719
/** Returns the URI of this configuration. */
URI uri();
1820
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,8 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
314314
embeddingSize,
315315
serviceSettings.getCommonSettings().maxInputTokens(),
316316
serviceSettings.getCommonSettings().modelId(),
317-
serviceSettings.getCommonSettings().rateLimitSettings()
317+
serviceSettings.getCommonSettings().rateLimitSettings(),
318+
serviceSettings.getCommonSettings().apiVersion()
318319
),
319320
serviceSettings.getEmbeddingType()
320321
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@
2020
import org.elasticsearch.inference.SimilarityMeasure;
2121
import org.elasticsearch.xcontent.XContentBuilder;
2222
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
23+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
2324
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
2425
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2526

2627
import java.io.IOException;
2728
import java.net.URI;
29+
import java.util.EnumSet;
30+
import java.util.Locale;
2831
import java.util.Map;
2932
import java.util.Objects;
3033

@@ -43,6 +46,18 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
4346
public static final String NAME = "cohere_service_settings";
4447
public static final String OLD_MODEL_ID_FIELD = "model";
4548
public static final String MODEL_ID = "model_id";
49+
public static final String API_VERSION = "api_version";
50+
public static final String MODEL_REQUIRED_FOR_V2_API = "The [service_settings.model_id] field is required for the Cohere V2 API.";
51+
52+
/**
 * The Cohere API versions known to these service settings.
 */
public enum CohereApiVersion {
    V1,
    V2;

    /**
     * Parses an API version from its textual name. The name is trimmed with
     * {@link String#trim()} and upper-cased using {@link Locale#ROOT} before the
     * constant lookup, so {@code "v1"} and {@code " V2 "} are both accepted.
     *
     * @param name the textual version name
     * @return the matching constant
     * @throws IllegalArgumentException if the normalized name matches no constant
     * @throws NullPointerException if {@code name} is {@code null}
     */
    public static CohereApiVersion fromString(String name) {
        final String normalized = name.trim().toUpperCase(Locale.ROOT);
        return CohereApiVersion.valueOf(normalized);
    }
}
60+
4661
private static final Logger logger = LogManager.getLogger(CohereServiceSettings.class);
4762
// Production key rate limits for all endpoints: https://docs.cohere.com/docs/going-live#production-key-specifications
4863
// 10K requests a minute
@@ -72,11 +87,53 @@ public static CohereServiceSettings fromMap(Map<String, Object> map, Configurati
7287
logger.info("The cohere [service_settings.model] field is deprecated. Please use [service_settings.model_id] instead.");
7388
}
7489

90+
var resolvedModelId = modelId(oldModelId, modelId);
91+
var apiVersion = apiVersionFromMap(map, context, validationException);
92+
if (apiVersion == CohereApiVersion.V2) {
93+
if (resolvedModelId == null) {
94+
validationException.addValidationError(MODEL_REQUIRED_FOR_V2_API);
95+
}
96+
}
97+
7598
if (validationException.validationErrors().isEmpty() == false) {
7699
throw validationException;
77100
}
78101

79-
return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, modelId(oldModelId, modelId), rateLimitSettings);
102+
return new CohereServiceSettings(
103+
uri,
104+
similarity,
105+
dims,
106+
maxInputTokens,
107+
modelId(oldModelId, modelId),
108+
rateLimitSettings,
109+
apiVersion
110+
);
111+
}
112+
113+
public static CohereApiVersion apiVersionFromMap(
114+
Map<String, Object> map,
115+
ConfigurationParseContext context,
116+
ValidationException validationException
117+
) {
118+
return switch (context) {
119+
case REQUEST -> CohereApiVersion.V2; // new endpoints all use the V2 API.
120+
case PERSISTENT -> {
121+
var apiVersion = ServiceUtils.extractOptionalEnum(
122+
map,
123+
API_VERSION,
124+
ModelConfigurations.SERVICE_SETTINGS,
125+
CohereApiVersion::fromString,
126+
EnumSet.allOf(CohereApiVersion.class),
127+
validationException
128+
);
129+
130+
if (apiVersion == null) {
131+
yield CohereApiVersion.V1; // If the API version is not persisted then it must be V1
132+
} else {
133+
yield apiVersion;
134+
}
135+
}
136+
};
80137
}
81138

82139
private static String modelId(@Nullable String model, @Nullable String modelId) {
@@ -89,21 +146,24 @@ private static String modelId(@Nullable String model, @Nullable String modelId)
89146
private final Integer maxInputTokens;
90147
private final String modelId;
91148
private final RateLimitSettings rateLimitSettings;
149+
private final CohereApiVersion apiVersion;
92150

93151
public CohereServiceSettings(
94152
@Nullable URI uri,
95153
@Nullable SimilarityMeasure similarity,
96154
@Nullable Integer dimensions,
97155
@Nullable Integer maxInputTokens,
98156
@Nullable String modelId,
99-
@Nullable RateLimitSettings rateLimitSettings
157+
@Nullable RateLimitSettings rateLimitSettings,
158+
CohereApiVersion apiVersion
100159
) {
101160
this.uri = uri;
102161
this.similarity = similarity;
103162
this.dimensions = dimensions;
104163
this.maxInputTokens = maxInputTokens;
105164
this.modelId = modelId;
106165
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
166+
this.apiVersion = apiVersion;
107167
}
108168

109169
public CohereServiceSettings(
@@ -112,9 +172,10 @@ public CohereServiceSettings(
112172
@Nullable Integer dimensions,
113173
@Nullable Integer maxInputTokens,
114174
@Nullable String modelId,
115-
@Nullable RateLimitSettings rateLimitSettings
175+
@Nullable RateLimitSettings rateLimitSettings,
176+
CohereApiVersion apiVersion
116177
) {
117-
this(createOptionalUri(url), similarity, dimensions, maxInputTokens, modelId, rateLimitSettings);
178+
this(createOptionalUri(url), similarity, dimensions, maxInputTokens, modelId, rateLimitSettings, apiVersion);
118179
}
119180

120181
public CohereServiceSettings(StreamInput in) throws IOException {
@@ -129,18 +190,29 @@ public CohereServiceSettings(StreamInput in) throws IOException {
129190
} else {
130191
rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS;
131192
}
193+
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)
194+
|| in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) {
195+
this.apiVersion = in.readEnum(CohereServiceSettings.CohereApiVersion.class);
196+
} else {
197+
this.apiVersion = CohereServiceSettings.CohereApiVersion.V1;
198+
}
132199
}
133200

134201
// should only be used for testing, public because it's accessed outside of the package
135-
public CohereServiceSettings() {
136-
this((URI) null, null, null, null, null, null);
202+
public CohereServiceSettings(CohereApiVersion apiVersion) {
203+
this((URI) null, null, null, null, null, null, apiVersion);
137204
}
138205

139206
@Override
140207
public RateLimitSettings rateLimitSettings() {
141208
return rateLimitSettings;
142209
}
143210

211+
/** Returns the Cohere API version held by these settings. */
@Override
212+
public CohereApiVersion apiVersion() {
213+
return apiVersion;
214+
}
215+
144216
public URI uri() {
145217
return uri;
146218
}
@@ -172,15 +244,14 @@ public String getWriteableName() {
172244
@Override
173245
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
174246
builder.startObject();
175-
176247
toXContentFragment(builder, params);
177-
178248
builder.endObject();
179249
return builder;
180250
}
181251

182252
public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException {
183-
return toXContentFragmentOfExposedFields(builder, params);
253+
toXContentFragmentOfExposedFields(builder, params);
254+
return builder.field(API_VERSION, apiVersion); // API version is persisted but not exposed to the user
184255
}
185256

186257
@Override
@@ -222,6 +293,10 @@ public void writeTo(StreamOutput out) throws IOException {
222293
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) {
223294
rateLimitSettings.writeTo(out);
224295
}
296+
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)
297+
|| out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) {
298+
out.writeEnum(apiVersion);
299+
}
225300
}
226301

227302
@Override
@@ -234,11 +309,12 @@ public boolean equals(Object o) {
234309
&& Objects.equals(dimensions, that.dimensions)
235310
&& Objects.equals(maxInputTokens, that.maxInputTokens)
236311
&& Objects.equals(modelId, that.modelId)
237-
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
312+
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
313+
&& apiVersion == that.apiVersion;
238314
}
239315

240316
@Override
241317
public int hashCode() {
242-
return Objects.hash(uri, similarity, dimensions, maxInputTokens, modelId, rateLimitSettings);
318+
return Objects.hash(uri, similarity, dimensions, maxInputTokens, modelId, rateLimitSettings, apiVersion);
243319
}
244320
}

0 commit comments

Comments
 (0)