Skip to content

Commit 1fa1ba7

Browse files
[ML] Add default Elastic Inference Service chat completion endpoint (#120847)
* Starting new auth class implementation * Fixing some tests * Working tests * Refactoring * Addressing feedback and pull main
1 parent 8e2044d commit 1fa1ba7

File tree

15 files changed

+732
-137
lines changed

15 files changed

+732
-137
lines changed

server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.Objects;
2222

2323
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
24+
import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION;
2425
import static org.elasticsearch.inference.TaskType.COMPLETION;
2526
import static org.elasticsearch.inference.TaskType.RERANK;
2627
import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
@@ -97,6 +98,10 @@ public static MinimalServiceSettings completion() {
9798
return new MinimalServiceSettings(COMPLETION, null, null, null);
9899
}
99100

101+
public static MinimalServiceSettings chatCompletion() {
102+
return new MinimalServiceSettings(CHAT_COMPLETION, null, null, null);
103+
}
104+
100105
public MinimalServiceSettings(Model model) {
101106
this(
102107
model.getTaskType(),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*
7+
* this file has been contributed to by a Generative AI
8+
*/
9+
10+
package org.elasticsearch.xpack.inference;
11+
12+
import org.elasticsearch.common.settings.SecureString;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.common.util.concurrent.ThreadContext;
15+
import org.elasticsearch.core.TimeValue;
16+
import org.elasticsearch.test.cluster.ElasticsearchCluster;
17+
import org.elasticsearch.test.cluster.FeatureFlag;
18+
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
19+
import org.elasticsearch.test.rest.ESRestTestCase;
20+
import org.junit.ClassRule;
21+
import org.junit.Rule;
22+
import org.junit.rules.RuleChain;
23+
import org.junit.rules.TestRule;
24+
25+
public class BaseMockEISAuthServerTest extends ESRestTestCase {
26+
27+
// The reason we're retrying is there's a race condition between the node retrieving the
28+
// authorization response and running the test. Retrieving the authorization should be very fast since
29+
// we're hosting a local mock server but it's possible it could respond slower. So in the even of a test failure
30+
// we'll automatically retry after waiting a second.
31+
@Rule
32+
public RetryRule retry = new RetryRule(3, TimeValue.timeValueSeconds(1));
33+
34+
private static final MockElasticInferenceServiceAuthorizationServer mockEISServer = MockElasticInferenceServiceAuthorizationServer
35+
.enabledWithRainbowSprinklesAndElser();
36+
37+
private static final ElasticsearchCluster cluster = ElasticsearchCluster.local()
38+
.distribution(DistributionType.DEFAULT)
39+
.setting("xpack.license.self_generated.type", "trial")
40+
.setting("xpack.security.enabled", "true")
41+
// Adding both settings unless one feature flag is disabled in a particular environment
42+
.setting("xpack.inference.elastic.url", mockEISServer::getUrl)
43+
// TODO remove this once we've removed DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG and EIS_GATEWAY_URL
44+
.setting("xpack.inference.eis.gateway.url", mockEISServer::getUrl)
45+
// This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin
46+
.plugin("inference-service-test")
47+
.user("x_pack_rest_user", "x-pack-test-password")
48+
.feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
49+
.build();
50+
51+
// The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating
52+
// it to the cluster as a setting.
53+
@ClassRule
54+
public static TestRule ruleChain = RuleChain.outerRule(mockEISServer).around(cluster);
55+
56+
@Override
57+
protected String getTestRestCluster() {
58+
return cluster.getHttpAddresses();
59+
}
60+
61+
@Override
62+
protected Settings restClientSettings() {
63+
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
64+
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
65+
}
66+
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,20 +171,20 @@ static String mockDenseServiceModelConfig() {
171171
""";
172172
}
173173

174-
protected void deleteModel(String modelId) throws IOException {
174+
static void deleteModel(String modelId) throws IOException {
175175
var request = new Request("DELETE", "_inference/" + modelId);
176176
var response = client().performRequest(request);
177177
assertStatusOkOrCreated(response);
178178
}
179179

180-
protected Response deleteModel(String modelId, String queryParams) throws IOException {
180+
static Response deleteModel(String modelId, String queryParams) throws IOException {
181181
var request = new Request("DELETE", "_inference/" + modelId + "?" + queryParams);
182182
var response = client().performRequest(request);
183183
assertStatusOkOrCreated(response);
184184
return response;
185185
}
186186

187-
protected void deleteModel(String modelId, TaskType taskType) throws IOException {
187+
static void deleteModel(String modelId, TaskType taskType) throws IOException {
188188
var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, modelId));
189189
var response = client().performRequest(request);
190190
assertStatusOkOrCreated(response);
@@ -229,12 +229,12 @@ protected void putSemanticText(String endpointId, String searchEndpointId, Strin
229229
assertStatusOkOrCreated(response);
230230
}
231231

232-
protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
232+
static Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
233233
String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
234234
return putRequest(endpoint, modelConfig);
235235
}
236236

237-
protected Map<String, Object> updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException {
237+
static Map<String, Object> updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException {
238238
String endpoint = Strings.format("_inference/%s/%s/_update", taskType, inferenceID);
239239
return putRequest(endpoint, modelConfig);
240240
}
@@ -265,12 +265,12 @@ protected void deletePipeline(String pipelineId) throws IOException {
265265
/**
266266
* Task type should be in modelConfig
267267
*/
268-
protected Map<String, Object> putModel(String modelId, String modelConfig) throws IOException {
268+
static Map<String, Object> putModel(String modelId, String modelConfig) throws IOException {
269269
String endpoint = Strings.format("_inference/%s", modelId);
270270
return putRequest(endpoint, modelConfig);
271271
}
272272

273-
Map<String, Object> putRequest(String endpoint, String body) throws IOException {
273+
static Map<String, Object> putRequest(String endpoint, String body) throws IOException {
274274
var request = new Request("PUT", endpoint);
275275
request.setJsonEntity(body);
276276
var response = client().performRequest(request);
@@ -318,18 +318,17 @@ protected Map<String, Object> getModel(String modelId) throws IOException {
318318
}
319319

320320
@SuppressWarnings("unchecked")
321-
protected List<Map<String, Object>> getModels(String modelId, TaskType taskType) throws IOException {
321+
static List<Map<String, Object>> getModels(String modelId, TaskType taskType) throws IOException {
322322
var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
323323
return (List<Map<String, Object>>) getInternalAsMap(endpoint).get("endpoints");
324324
}
325325

326326
@SuppressWarnings("unchecked")
327-
protected List<Map<String, Object>> getAllModels() throws IOException {
328-
var endpoint = Strings.format("_inference/_all");
327+
static List<Map<String, Object>> getAllModels() throws IOException {
329328
return (List<Map<String, Object>>) getInternalAsMap("_inference/_all").get("endpoints");
330329
}
331330

332-
private Map<String, Object> getInternalAsMap(String endpoint) throws IOException {
331+
private static Map<String, Object> getInternalAsMap(String endpoint) throws IOException {
333332
var request = new Request("GET", endpoint);
334333
var response = client().performRequest(request);
335334
assertStatusOkOrCreated(response);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*
7+
* this file has been contributed to by a Generative AI
8+
*/
9+
10+
package org.elasticsearch.xpack.inference;
11+
12+
import org.elasticsearch.inference.TaskType;
13+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature;
14+
15+
import java.io.IOException;
16+
17+
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels;
18+
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels;
19+
import static org.hamcrest.Matchers.hasSize;
20+
21+
public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest {
22+
23+
public void testGetDefaultEndpoints() throws IOException {
24+
var allModels = getAllModels();
25+
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
26+
27+
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
28+
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
29+
assertThat(allModels, hasSize(4));
30+
assertThat(chatCompletionModels, hasSize(1));
31+
32+
for (var model : chatCompletionModels) {
33+
assertEquals("chat_completion", model.get("task_type"));
34+
}
35+
} else {
36+
assertThat(allModels, hasSize(3));
37+
assertThat(chatCompletionModels, hasSize(0));
38+
}
39+
40+
}
41+
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,8 @@
1111

1212
import org.elasticsearch.client.Request;
1313
import org.elasticsearch.common.Strings;
14-
import org.elasticsearch.common.settings.SecureString;
15-
import org.elasticsearch.common.settings.Settings;
16-
import org.elasticsearch.common.util.concurrent.ThreadContext;
17-
import org.elasticsearch.core.TimeValue;
1814
import org.elasticsearch.inference.TaskType;
19-
import org.elasticsearch.test.cluster.ElasticsearchCluster;
20-
import org.elasticsearch.test.cluster.FeatureFlag;
21-
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
22-
import org.elasticsearch.test.rest.ESRestTestCase;
2315
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature;
24-
import org.junit.ClassRule;
25-
import org.junit.Rule;
26-
import org.junit.rules.RuleChain;
27-
import org.junit.rules.TestRule;
2816

2917
import java.io.IOException;
3018
import java.util.ArrayList;
@@ -35,47 +23,7 @@
3523
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
3624
import static org.hamcrest.Matchers.equalTo;
3725

38-
public class InferenceGetServicesIT extends ESRestTestCase {
39-
40-
// The reason we're retrying is there's a race condition between the node retrieving the
41-
// authorization response and running the test. Retrieving the authorization should be very fast since
42-
// we're hosting a local mock server but it's possible it could respond slower. So in the even of a test failure
43-
// we'll automatically retry after waiting a second.
44-
@Rule
45-
public RetryRule retry = new RetryRule(3, TimeValue.timeValueSeconds(1));
46-
47-
private static final MockElasticInferenceServiceAuthorizationServer mockEISServer = MockElasticInferenceServiceAuthorizationServer
48-
.enabledWithSparseEmbeddingsAndChatCompletion();
49-
50-
private static final ElasticsearchCluster cluster = ElasticsearchCluster.local()
51-
.distribution(DistributionType.DEFAULT)
52-
.setting("xpack.license.self_generated.type", "trial")
53-
.setting("xpack.security.enabled", "true")
54-
// Adding both settings unless one feature flag is disabled in a particular environment
55-
.setting("xpack.inference.elastic.url", mockEISServer::getUrl)
56-
// TODO remove this once we've removed DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG and EIS_GATEWAY_URL
57-
.setting("xpack.inference.eis.gateway.url", mockEISServer::getUrl)
58-
// This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin
59-
.plugin("inference-service-test")
60-
.user("x_pack_rest_user", "x-pack-test-password")
61-
.feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
62-
.build();
63-
64-
// The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating
65-
// it to the cluster as a setting.
66-
@ClassRule
67-
public static TestRule ruleChain = RuleChain.outerRule(mockEISServer).around(cluster);
68-
69-
@Override
70-
protected String getTestRestCluster() {
71-
return cluster.getHttpAddresses();
72-
}
73-
74-
@Override
75-
protected Settings restClientSettings() {
76-
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
77-
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
78-
}
26+
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
7927

8028
@SuppressWarnings("unchecked")
8129
public void testGetServicesWithoutTaskType() throws IOException {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,19 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule
2323
private static final Logger logger = LogManager.getLogger(MockElasticInferenceServiceAuthorizationServer.class);
2424
private final MockWebServer webServer = new MockWebServer();
2525

26-
public static MockElasticInferenceServiceAuthorizationServer enabledWithSparseEmbeddingsAndChatCompletion() {
26+
public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowSprinklesAndElser() {
2727
var server = new MockElasticInferenceServiceAuthorizationServer();
2828

2929
String responseJson = """
3030
{
3131
"models": [
3232
{
33-
"model_name": "model-a",
34-
"task_types": ["embed/text/sparse", "chat"]
33+
"model_name": "rainbow-sprinkles",
34+
"task_types": ["chat"]
35+
},
36+
{
37+
"model_name": "elser-v2",
38+
"task_types": ["embed/text/sparse"]
3539
}
3640
]
3741
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,17 @@ protected void masterOperation(
108108
return;
109109
}
110110

111+
if (modelRegistry.containsDefaultConfigId(request.getInferenceEntityId())) {
112+
listener.onFailure(
113+
new ElasticsearchStatusException(
114+
"[{}] is a reserved inference ID. Cannot create a new inference endpoint with a reserved ID.",
115+
RestStatus.BAD_REQUEST,
116+
request.getInferenceEntityId()
117+
)
118+
);
119+
return;
120+
}
121+
111122
var requestAsMap = requestToMap(request);
112123
var resolvedTaskType = ServiceUtils.resolveTaskType(request.getTaskType(), (String) requestAsMap.remove(TaskType.NAME));
113124

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,16 @@ public ModelRegistry(Client client) {
114114
defaultConfigIds = new HashMap<>();
115115
}
116116

117+
/**
118+
* Returns true if the provided inference entity id is the same as one of the default
119+
* endpoints ids.
120+
* @param inferenceEntityId the id to search for
121+
* @return true if we find a match and false if not
122+
*/
123+
public boolean containsDefaultConfigId(String inferenceEntityId) {
124+
return defaultConfigIds.containsKey(inferenceEntityId);
125+
}
126+
117127
/**
118128
* Set the default inference ids provided by the services
119129
* @param defaultConfigId The default

0 commit comments

Comments
 (0)