Skip to content

Commit 8ae48ca

Browse files
jan-elasticelasticsearchmachine
andcommitted
Test ML model server (#120270)
* Fix model downloading for very small models. * Test MlModelServer * Tiny ELSER * unmute TextEmbeddingCrudIT and DefaultEndPointsIT * update ELSER * Improve MlModelServer * tiny E5 * more logging * improved E5 model * tiny reranker * scan for ports * [CI] Auto commit changes from spotless * Serve default models when optimized model is requested * @ClassRule * polish code * Respect dynamic setting ML model repo * fix metadata for optimized models * improve logging --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 582c844 commit 8ae48ca

File tree

17 files changed

+262
-26
lines changed

17 files changed

+262
-26
lines changed

muted-tests.yml

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,6 @@ tests:
199199
- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
200200
method: test {categorize.Categorize ASYNC}
201201
issue: https://github.com/elastic/elasticsearch/issues/116373
202-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
203-
method: testPutE5WithTrainedModelAndInference
204-
issue: https://github.com/elastic/elasticsearch/issues/114023
205-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
206-
method: testPutE5Small_withPlatformAgnosticVariant
207-
issue: https://github.com/elastic/elasticsearch/issues/113983
208202
- class: org.elasticsearch.datastreams.LazyRolloverDuringDisruptionIT
209203
method: testRolloverIsExecutedOnce
210204
issue: https://github.com/elastic/elasticsearch/issues/112634
@@ -214,9 +208,6 @@ tests:
214208
- class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityWithApmTracingRestIT
215209
method: testTracingCrossCluster
216210
issue: https://github.com/elastic/elasticsearch/issues/112731
217-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
218-
method: testPutE5Small_withPlatformSpecificVariant
219-
issue: https://github.com/elastic/elasticsearch/issues/113950
220211
- class: org.elasticsearch.smoketest.DocsClientYamlTestSuiteIT
221212
method: test {yaml=reference/rest-api/usage/line_38}
222213
issue: https://github.com/elastic/elasticsearch/issues/113694
@@ -226,9 +217,6 @@ tests:
226217
- class: org.elasticsearch.reservedstate.service.FileSettingsServiceTests
227218
method: testProcessFileChanges
228219
issue: https://github.com/elastic/elasticsearch/issues/115280
229-
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
230-
method: testInferDeploysDefaultE5
231-
issue: https://github.com/elastic/elasticsearch/issues/115361
232220
- class: org.elasticsearch.xpack.inference.InferenceCrudIT
233221
method: testSupportedStream
234222
issue: https://github.com/elastic/elasticsearch/issues/113430
@@ -285,9 +273,6 @@ tests:
285273
- class: org.elasticsearch.xpack.esql.qa.mixed.EsqlClientYamlIT
286274
method: test {p0=esql/61_enrich_ip/IP strings}
287275
issue: https://github.com/elastic/elasticsearch/issues/116529
288-
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
289-
method: testInferDeploysDefaultElser
290-
issue: https://github.com/elastic/elasticsearch/issues/114913
291276
- class: org.elasticsearch.threadpool.SimpleThreadPoolIT
292277
method: testThreadPoolMetrics
293278
issue: https://github.com/elastic/elasticsearch/issues/108320
@@ -336,9 +321,6 @@ tests:
336321
- class: org.elasticsearch.xpack.searchablesnapshots.RetrySearchIntegTests
337322
method: testRetryPointInTime
338323
issue: https://github.com/elastic/elasticsearch/issues/117116
339-
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
340-
method: testMultipleInferencesTriggeringDownloadAndDeploy
341-
issue: https://github.com/elastic/elasticsearch/issues/117208
342324
- class: org.elasticsearch.xpack.spatial.search.GeoGridAggAndQueryConsistencyIT
343325
method: testGeoPointGeoTile
344326
issue: https://github.com/elastic/elasticsearch/issues/115818

x-pack/plugin/inference/qa/inference-service-tests/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
apply plugin: 'elasticsearch.internal-java-rest-test'
22

33
dependencies {
4+
javaRestTestImplementation project(path: xpackModule('core'))
45
javaRestTestImplementation project(path: xpackModule('inference'))
56
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
67
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
2323
import org.elasticsearch.test.rest.ESRestTestCase;
2424
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
25+
import org.junit.Before;
2526
import org.junit.ClassRule;
2627

2728
import java.io.IOException;
@@ -37,6 +38,7 @@
3738
import static org.hamcrest.Matchers.hasSize;
3839

3940
public class InferenceBaseRestTest extends ESRestTestCase {
41+
4042
@ClassRule
4143
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
4244
.distribution(DistributionType.DEFAULT)
@@ -46,6 +48,22 @@ public class InferenceBaseRestTest extends ESRestTestCase {
4648
.user("x_pack_rest_user", "x-pack-test-password")
4749
.build();
4850

51+
@ClassRule
52+
public static MlModelServer mlModelServer = new MlModelServer();
53+
54+
@Before
55+
public void setMlModelRepository() throws IOException {
56+
logger.info("setting ML model repository to: {}", mlModelServer.getUrl());
57+
var request = new Request("PUT", "/_cluster/settings");
58+
request.setJsonEntity(Strings.format("""
59+
{
60+
"persistent": {
61+
"xpack.ml.model_repository": "%s"
62+
}
63+
}""", mlModelServer.getUrl()));
64+
assertOK(client().performRequest(request));
65+
}
66+
4967
@Override
5068
protected String getTestRestCluster() {
5169
return cluster.getHttpAddresses();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import com.sun.net.httpserver.HttpExchange;
11+
import com.sun.net.httpserver.HttpServer;
12+
13+
import org.apache.http.HttpHeaders;
14+
import org.apache.http.HttpStatus;
15+
import org.apache.http.client.utils.URIBuilder;
16+
import org.elasticsearch.logging.LogManager;
17+
import org.elasticsearch.logging.Logger;
18+
import org.elasticsearch.test.fixture.HttpHeaderParser;
19+
import org.elasticsearch.xcontent.XContentParser;
20+
import org.elasticsearch.xcontent.XContentParserConfiguration;
21+
import org.elasticsearch.xcontent.XContentType;
22+
import org.elasticsearch.xpack.core.XPackSettings;
23+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
24+
import org.junit.rules.TestRule;
25+
import org.junit.runner.Description;
26+
import org.junit.runners.model.Statement;
27+
28+
import java.io.ByteArrayInputStream;
29+
import java.io.IOException;
30+
import java.io.InputStream;
31+
import java.io.OutputStream;
32+
import java.net.InetSocketAddress;
33+
import java.nio.charset.StandardCharsets;
34+
import java.util.Random;
35+
import java.util.concurrent.ExecutorService;
36+
import java.util.concurrent.Executors;
37+
38+
/**
39+
* Simple model server to serve ML models.
40+
* The URL path corresponds to a file name in this class's resources.
41+
* If the file is found, its content is returned, otherwise 404.
42+
* Respects a range header to serve partial content.
43+
*/
44+
public class MlModelServer implements TestRule {
45+
46+
private static final String HOST = "localhost";
47+
private static final Logger logger = LogManager.getLogger(MlModelServer.class);
48+
49+
private int port;
50+
51+
public String getUrl() {
52+
return new URIBuilder().setScheme("http").setHost(HOST).setPort(port).toString();
53+
}
54+
55+
private void handle(HttpExchange exchange) throws IOException {
56+
String rangeHeader = exchange.getRequestHeaders().getFirst(HttpHeaders.RANGE);
57+
HttpHeaderParser.Range range = rangeHeader != null ? HttpHeaderParser.parseRangeHeader(rangeHeader) : null;
58+
logger.info("request: {} range={}", exchange.getRequestURI().getPath(), range);
59+
60+
try (InputStream is = getInputStream(exchange)) {
61+
int httpStatus;
62+
long numBytes;
63+
if (is == null) {
64+
httpStatus = HttpStatus.SC_NOT_FOUND;
65+
numBytes = 0;
66+
} else if (range == null) {
67+
httpStatus = HttpStatus.SC_OK;
68+
numBytes = is.available();
69+
} else {
70+
httpStatus = HttpStatus.SC_PARTIAL_CONTENT;
71+
is.skipNBytes(range.start());
72+
numBytes = range.end() - range.start() + 1;
73+
}
74+
logger.info("response: {} {}", exchange.getRequestURI().getPath(), httpStatus);
75+
exchange.sendResponseHeaders(httpStatus, numBytes);
76+
try (OutputStream os = exchange.getResponseBody()) {
77+
while (numBytes > 0) {
78+
byte[] bytes = is.readNBytes((int) Math.min(1 << 20, numBytes));
79+
os.write(bytes);
80+
numBytes -= bytes.length;
81+
}
82+
}
83+
}
84+
}
85+
86+
private InputStream getInputStream(HttpExchange exchange) throws IOException {
87+
String path = exchange.getRequestURI().getPath().substring(1); // Strip leading slash
88+
String modelId = path.substring(0, path.indexOf('.'));
89+
String extension = path.substring(path.indexOf('.') + 1);
90+
91+
// If a model specifically optimized for some platform is requested,
92+
// serve the default non-optimized model instead, which is compatible.
93+
String defaultModelId = modelId;
94+
for (String platform : XPackSettings.ML_NATIVE_CODE_PLATFORMS) {
95+
defaultModelId = defaultModelId.replace("_" + platform, "");
96+
}
97+
98+
ClassLoader classloader = Thread.currentThread().getContextClassLoader();
99+
InputStream is = classloader.getResourceAsStream(defaultModelId + "." + extension);
100+
if (is != null && modelId.equals(defaultModelId) == false && extension.equals("metadata.json")) {
101+
// When an optimized version is requested, fix the default metadata,
102+
// so that it contains the correct model ID.
103+
try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, is.readAllBytes())) {
104+
is.close();
105+
ModelPackageConfig packageConfig = ModelPackageConfig.fromXContentLenient(parser);
106+
packageConfig = new ModelPackageConfig.Builder(packageConfig).setPackedModelId(modelId).build();
107+
is = new ByteArrayInputStream(packageConfig.toString().getBytes(StandardCharsets.UTF_8));
108+
}
109+
}
110+
return is;
111+
}
112+
113+
@Override
114+
public Statement apply(Statement statement, Description description) {
115+
return new Statement() {
116+
@Override
117+
public void evaluate() throws Throwable {
118+
logger.info("Starting ML model server");
119+
HttpServer server = HttpServer.create();
120+
while (true) {
121+
port = new Random().nextInt(10000, 65536);
122+
try {
123+
server.bind(new InetSocketAddress(HOST, port), 1);
124+
} catch (Exception e) {
125+
continue;
126+
}
127+
break;
128+
}
129+
logger.info("Bound ML model server to port {}", port);
130+
131+
ExecutorService executor = Executors.newCachedThreadPool();
132+
server.setExecutor(executor);
133+
server.createContext("/", MlModelServer.this::handle);
134+
server.start();
135+
136+
try {
137+
statement.evaluate();
138+
} finally {
139+
logger.info("Stopping ML model server on port {}", port);
140+
server.stop(1);
141+
executor.shutdown();
142+
}
143+
}
144+
};
145+
}
146+
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
import static org.hamcrest.Matchers.containsString;
2020

21-
// This test was previously disabled in CI due to the models being too large
22-
// See "https://github.com/elastic/elasticsearch/issues/105198".
2321
public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
2422

2523
public void testPutE5Small_withNoModelVariant() {
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"packaged_model_id": "elser_model_2",
3+
"minimum_version": "11.0.0",
4+
"size": 1859242,
5+
"sha256": "602dbccfb2746e5700bf65d8019b06fb2ec1e3c5bfb980eb2005fc17c1bfe0c0",
6+
"description": "Elastic Learned Sparse EncodeR v2",
7+
"model_type": "pytorch",
8+
"tags": [
9+
"elastic"
10+
],
11+
"inference_config": {
12+
"text_expansion": {
13+
"tokenization": {
14+
"bert": {
15+
"do_lower_case": true,
16+
"with_special_tokens": true,
17+
"max_sequence_length": 512,
18+
"truncate": "first",
19+
"span": -1
20+
}
21+
}
22+
}
23+
},
24+
"vocabulary_file": "elser_model_2.vocab.json"
25+
}
Binary file not shown.

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/elser_model_2.vocab.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
{
2+
"packaged_model_id": "multilingual-e5-small",
3+
"minimum_version": "12.0.0",
4+
"size": 5531160,
5+
"sha256": "92e24566eff554d3a6808cc62731dbecf32db63e01801f3f62210aa9131c7a8b",
6+
"description": "E5 small multilingual",
7+
"model_type": "pytorch",
8+
"tags": [],
9+
"inference_config": {
10+
"text_embedding": {
11+
"tokenization": {
12+
"xlm_roberta": {
13+
"do_lower_case": false,
14+
"with_special_tokens": true,
15+
"max_sequence_length": 512,
16+
"truncate": "first",
17+
"span": -1
18+
}
19+
},
20+
"embedding_size": 384
21+
}
22+
},
23+
"prefix_strings": {
24+
"search": "query: ",
25+
"ingest": "passage: "
26+
},
27+
"metadata": {
28+
"per_allocation_memory_bytes": 557785256,
29+
"per_deployment_memory_bytes": 470031872
30+
},
31+
"vocabulary_file": "multilingual-e5-small.vocab.json"
32+
}

0 commit comments

Comments
 (0)