Skip to content

Commit d37dd1f

Browse files
committed
Upgrade test
1 parent fcbbfa0 commit d37dd1f

File tree

16 files changed

+150
-120
lines changed

16 files changed

+150
-120
lines changed

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.elasticsearch.common.Strings;
1313
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.test.http.MockRequest;
1415
import org.elasticsearch.test.http.MockResponse;
1516
import org.elasticsearch.test.http.MockWebServer;
1617
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
@@ -36,10 +37,15 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
3637
// TODO: replace with proper test features
3738
private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0";
3839
private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0";
40+
private static final String V2_API = "gte_v8.19.0";
3941

4042
private static MockWebServer cohereEmbeddingsServer;
4143
private static MockWebServer cohereRerankServer;
4244

45+
private enum ApiVersion {
46+
V1, V2
47+
}
48+
4349
public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) {
4450
super(upgradedNodes);
4551
}
@@ -62,15 +68,18 @@ public static void shutdown() {
6268
@SuppressWarnings("unchecked")
6369
public void testCohereEmbeddings() throws IOException {
6470
var embeddingsSupported = oldClusterHasFeature(COHERE_EMBEDDINGS_ADDED_TEST_FEATURE);
65-
String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
6671
assumeTrue("Cohere embedding service supported", embeddingsSupported);
6772

73+
String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
74+
ApiVersion oldClusterApiVersion = oldClusterHasFeature(V2_API) ? ApiVersion.V2 : ApiVersion.V1;
75+
6876
final String oldClusterIdInt8 = "old-cluster-embeddings-int8";
6977
final String oldClusterIdFloat = "old-cluster-embeddings-float";
7078

7179
var testTaskType = TaskType.TEXT_EMBEDDING;
7280

7381
if (isOldCluster()) {
82+
7483
// queue a response as PUT will call the service
7584
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
7685
put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
@@ -128,13 +137,17 @@ public void testCohereEmbeddings() throws IOException {
128137

129138
// Inference on old cluster models
130139
assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE);
140+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", oldClusterApiVersion);
131141
assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT);
142+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", oldClusterApiVersion);
132143

133144
{
134145
final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte";
135146

147+
// new endpoints use the V2 API
136148
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
137149
put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), testTaskType);
150+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
138151

139152
configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdByte).get("endpoints");
140153
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
@@ -146,26 +159,31 @@ public void testCohereEmbeddings() throws IOException {
146159
{
147160
final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8";
148161

162+
// new endpoints use the V2 API
149163
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
150164
put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
165+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
151166

152167
configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdInt8).get("endpoints");
153168
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
154169
assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte
155170

156171
assertEmbeddingInference(upgradedClusterIdInt8, CohereEmbeddingType.INT8);
172+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
157173
delete(upgradedClusterIdInt8);
158174
}
159175
{
160176
final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float";
161177
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
162178
put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType);
179+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
163180

164181
configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdFloat).get("endpoints");
165182
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
166183
assertThat(serviceSettings, hasEntry("embedding_type", "float"));
167184

168185
assertEmbeddingInference(upgradedClusterIdFloat, CohereEmbeddingType.FLOAT);
186+
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
169187
delete(upgradedClusterIdFloat);
170188
}
171189

@@ -174,6 +192,17 @@ public void testCohereEmbeddings() throws IOException {
174192
}
175193
}
176194

195+
private void assertVersionInPath(MockRequest request, String endpoint, ApiVersion apiVersion) {
196+
switch (apiVersion) {
197+
case V2:
198+
assertEquals("/v2/" + endpoint, request.getUri().getPath());
199+
break;
200+
case V1:
201+
assertEquals("/v1/" + endpoint, request.getUri().getPath());
202+
break;
203+
}
204+
}
205+
177206
void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException {
178207
switch (type) {
179208
case INT8:
@@ -191,9 +220,11 @@ void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) thro
191220
@SuppressWarnings("unchecked")
192221
public void testRerank() throws IOException {
193222
var rerankSupported = oldClusterHasFeature(COHERE_RERANK_ADDED_TEST_FEATURE);
194-
String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
195223
assumeTrue("Cohere rerank service supported", rerankSupported);
196224

225+
String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
226+
ApiVersion oldClusterApiVersion = oldClusterHasFeature(V2_API) ? ApiVersion.V2 : ApiVersion.V1;
227+
197228
final String oldClusterId = "old-cluster-rerank";
198229
final String upgradedClusterId = "upgraded-cluster-rerank";
199230

@@ -216,7 +247,6 @@ public void testRerank() throws IOException {
216247
assertThat(taskSettings, hasEntry("top_n", 3));
217248

218249
assertRerank(oldClusterId);
219-
220250
} else if (isUpgradedCluster()) {
221251
// check old cluster model
222252
var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get("endpoints");
@@ -227,6 +257,7 @@ public void testRerank() throws IOException {
227257
assertThat(taskSettings, hasEntry("top_n", 3));
228258

229259
assertRerank(oldClusterId);
260+
assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", oldClusterApiVersion);
230261

231262
// New endpoint
232263
cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
@@ -235,6 +266,7 @@ public void testRerank() throws IOException {
235266
assertThat(configs, hasSize(1));
236267

237268
assertRerank(upgradedClusterId);
269+
assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", ApiVersion.V2);
238270

239271
delete(oldClusterId);
240272
delete(upgradedClusterId);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereAccount.java

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,35 @@
77

88
package org.elasticsearch.xpack.inference.services.cohere;
99

10-
import org.elasticsearch.common.CheckedSupplier;
10+
import org.apache.http.client.utils.URIBuilder;
11+
import org.elasticsearch.ElasticsearchStatusException;
12+
import org.elasticsearch.common.Strings;
1113
import org.elasticsearch.common.settings.SecureString;
14+
import org.elasticsearch.rest.RestStatus;
15+
import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils;
1216

1317
import java.net.URI;
1418
import java.net.URISyntaxException;
1519
import java.util.Objects;
1620

17-
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
18-
19-
public record CohereAccount(URI uri, SecureString apiKey) {
20-
21-
public static CohereAccount of(CohereModel model, CheckedSupplier<URI, URISyntaxException> uriBuilder) {
22-
var uri = buildUri(model.uri(), "Cohere", uriBuilder);
23-
24-
return new CohereAccount(uri, model.apiKey());
21+
public record CohereAccount(URI baseUri, SecureString apiKey) {
22+
23+
public static CohereAccount of(CohereModel model) {
24+
try {
25+
var uri = model.baseUri() != null ? model.baseUri() : new URIBuilder().setScheme("https").setHost(CohereUtils.HOST).build();
26+
return new CohereAccount(uri, model.apiKey());
27+
} catch (URISyntaxException e) {
28+
// using bad request here so that potentially sensitive URL information does not get logged
29+
throw new ElasticsearchStatusException(
30+
Strings.format("Failed to construct %s URL", CohereService.NAME),
31+
RestStatus.BAD_REQUEST,
32+
e
33+
);
34+
}
2535
}
2636

2737
public CohereAccount {
28-
Objects.requireNonNull(uri);
38+
Objects.requireNonNull(baseUri);
2939
Objects.requireNonNull(apiKey);
3040
}
3141
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public int rateLimitGroupingHash() {
7272
return apiKey().hashCode();
7373
}
7474

75-
public URI uri() {
75+
public URI baseUri() {
7676
return rateLimitServiceSettings.uri();
7777
}
7878
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public ExecutableAction accept(CohereActionVisitor visitor, Map<String, Object>
9090
}
9191

9292
@Override
93-
public URI uri() {
93+
public URI baseUri() {
9494
return getServiceSettings().getCommonSettings().uri();
9595
}
9696
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,23 @@
99

1010
import org.apache.http.HttpHeaders;
1111
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.client.utils.URIBuilder;
1213
import org.apache.http.entity.ByteArrayEntity;
14+
import org.elasticsearch.ElasticsearchStatusException;
1315
import org.elasticsearch.common.Strings;
1416
import org.elasticsearch.core.Nullable;
17+
import org.elasticsearch.rest.RestStatus;
1518
import org.elasticsearch.xcontent.ToXContentObject;
1619
import org.elasticsearch.xcontent.XContentType;
1720
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1821
import org.elasticsearch.xpack.inference.external.request.Request;
1922
import org.elasticsearch.xpack.inference.services.cohere.CohereAccount;
23+
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
2024

2125
import java.net.URI;
26+
import java.net.URISyntaxException;
2227
import java.nio.charset.StandardCharsets;
28+
import java.util.List;
2329
import java.util.Objects;
2430

2531
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
@@ -46,7 +52,7 @@ protected CohereRequest(CohereAccount account, String inferenceEntityId, @Nullab
4652

4753
@Override
4854
public HttpRequest createHttpRequest() {
49-
HttpPost httpPost = new HttpPost(account.uri());
55+
HttpPost httpPost = new HttpPost(getURI());
5056

5157
ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(this).getBytes(StandardCharsets.UTF_8));
5258
httpPost.setEntity(byteEntity);
@@ -68,7 +74,25 @@ public boolean isStreaming() {
6874

6975
@Override
7076
public URI getURI() {
71-
return account.uri();
77+
return buildUri(account.baseUri());
78+
}
79+
80+
/**
81+
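* Returns the URL path segments that are set on the account's base URI to form the request URL (for example, the API version followed by the endpoint name).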
82+
* @return List of segments that make up the path of the request.
83+
*/
84+
protected abstract List<String> pathSegments();
85+
86+
private URI buildUri(URI baseUri) {
87+
try {
88+
return new URIBuilder(baseUri).setPathSegments(pathSegments()).build();
89+
} catch (URISyntaxException e) {
90+
throw new ElasticsearchStatusException(
91+
Strings.format("Failed to construct %s URL", CohereService.NAME),
92+
RestStatus.BAD_REQUEST,
93+
e
94+
);
95+
}
7296
}
7397

7498
public String getModelId() {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,21 @@
77

88
package org.elasticsearch.xpack.inference.services.cohere.request.v1;
99

10-
import org.apache.http.client.utils.URIBuilder;
1110
import org.elasticsearch.xcontent.XContentBuilder;
1211
import org.elasticsearch.xpack.inference.services.cohere.CohereAccount;
1312
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel;
1413
import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest;
1514
import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils;
1615

1716
import java.io.IOException;
18-
import java.net.URI;
19-
import java.net.URISyntaxException;
2017
import java.util.List;
2118
import java.util.Objects;
2219

2320
public class CohereV1CompletionRequest extends CohereRequest {
2421
private final List<String> input;
2522

2623
public CohereV1CompletionRequest(List<String> input, CohereCompletionModel model, boolean stream) {
27-
super(
28-
CohereAccount.of(model, CohereV1CompletionRequest::buildDefaultUri),
29-
model.getInferenceEntityId(),
30-
model.getServiceSettings().modelId(),
31-
stream
32-
);
24+
super(CohereAccount.of(model), model.getInferenceEntityId(), model.getServiceSettings().modelId(), stream);
3325

3426
this.input = Objects.requireNonNull(input);
3527
}
@@ -49,10 +41,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4941
return builder;
5042
}
5143

52-
public static URI buildDefaultUri() throws URISyntaxException {
53-
return new URIBuilder().setScheme("https")
54-
.setHost(CohereUtils.HOST)
55-
.setPathSegments(CohereUtils.VERSION_1, CohereUtils.CHAT_PATH)
56-
.build();
44+
@Override
45+
protected List<String> pathSegments() {
46+
return List.of(CohereUtils.VERSION_1, CohereUtils.CHAT_PATH);
5747
}
5848
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.services.cohere.request.v1;
99

10-
import org.apache.http.client.utils.URIBuilder;
1110
import org.elasticsearch.inference.InputType;
1211
import org.elasticsearch.xcontent.XContentBuilder;
1312
import org.elasticsearch.xpack.inference.services.cohere.CohereAccount;
@@ -20,8 +19,6 @@
2019
import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils;
2120

2221
import java.io.IOException;
23-
import java.net.URI;
24-
import java.net.URISyntaxException;
2522
import java.util.List;
2623
import java.util.Objects;
2724

@@ -34,7 +31,7 @@ public class CohereV1EmbeddingsRequest extends CohereRequest {
3431

3532
public CohereV1EmbeddingsRequest(List<String> input, InputType inputType, CohereEmbeddingsModel embeddingsModel) {
3633
super(
37-
CohereAccount.of(embeddingsModel, CohereV1EmbeddingsRequest::buildDefaultUri),
34+
CohereAccount.of(embeddingsModel),
3835
embeddingsModel.getInferenceEntityId(),
3936
embeddingsModel.getServiceSettings().getCommonSettings().modelId(),
4037
false
@@ -46,11 +43,9 @@ public CohereV1EmbeddingsRequest(List<String> input, InputType inputType, Cohere
4643
embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType();
4744
}
4845

49-
public static URI buildDefaultUri() throws URISyntaxException {
50-
return new URIBuilder().setScheme("https")
51-
.setHost(CohereUtils.HOST)
52-
.setPathSegments(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH)
53-
.build();
46+
@Override
47+
protected List<String> pathSegments() {
48+
return List.of(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH);
5449
}
5550

5651
@Override

0 commit comments

Comments
 (0)