Skip to content

Commit 6fd99c6

Browse files
jan-elastic and elasticsearchmachine authored
Test ML model server (#120270)
* Fix model downloading for very small models.
* Test MlModelServer
* Tiny ELSER
* unmute TextEmbeddingCrudIT and DefaultEndPointsIT
* update ELSER
* Improve MlModelServer
* tiny E5
* more logging
* improved E5 model
* tiny reranker
* scan for ports
* [CI] Auto commit changes from spotless
* Serve default models when optimized model is requested
* @ClassRule
* polish code
* Respect dynamic setting ML model repo
* fix metadata for optimized models
* improve logging

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 3d43f73 commit 6fd99c6

File tree

17 files changed

+262
-29
lines changed

17 files changed

+262
-29
lines changed

muted-tests.yml

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,9 @@ tests:
5252
- class: org.elasticsearch.xpack.transform.integration.TransformIT
5353
method: testStopWaitForCheckpoint
5454
issue: https://github.com/elastic/elasticsearch/issues/106113
55-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
56-
method: testPutE5Small_withPlatformAgnosticVariant
57-
issue: https://github.com/elastic/elasticsearch/issues/113983
58-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
59-
method: testPutE5WithTrainedModelAndInference
60-
issue: https://github.com/elastic/elasticsearch/issues/114023
61-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
62-
method: testPutE5Small_withPlatformSpecificVariant
63-
issue: https://github.com/elastic/elasticsearch/issues/113950
6455
- class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityWithApmTracingRestIT
6556
method: testTracingCrossCluster
6657
issue: https://github.com/elastic/elasticsearch/issues/112731
67-
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
68-
method: testInferDeploysDefaultE5
69-
issue: https://github.com/elastic/elasticsearch/issues/115361
7058
- class: org.elasticsearch.xpack.restart.MLModelDeploymentFullClusterRestartIT
7159
method: testDeploymentSurvivesRestart {cluster=UPGRADED}
7260
issue: https://github.com/elastic/elasticsearch/issues/115528
@@ -110,9 +98,6 @@ tests:
11098
- class: org.elasticsearch.xpack.apmdata.APMYamlTestSuiteIT
11199
method: test {yaml=/10_apm/Test template reinstallation}
112100
issue: https://github.com/elastic/elasticsearch/issues/116445
113-
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
114-
method: testMultipleInferencesTriggeringDownloadAndDeploy
115-
issue: https://github.com/elastic/elasticsearch/issues/117208
116101
- class: org.elasticsearch.ingest.geoip.EnterpriseGeoIpDownloaderIT
117102
method: testEnterpriseDownloaderTask
118103
issue: https://github.com/elastic/elasticsearch/issues/115163
@@ -125,9 +110,6 @@ tests:
125110
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
126111
method: test {p0=snapshot/10_basic/Create a source only snapshot and then restore it}
127112
issue: https://github.com/elastic/elasticsearch/issues/117295
128-
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
129-
method: testInferDeploysDefaultElser
130-
issue: https://github.com/elastic/elasticsearch/issues/114913
131113
- class: org.elasticsearch.xpack.inference.InferenceRestIT
132114
method: test {p0=inference/40_semantic_text_query/Query a field that uses the default ELSER 2 endpoint}
133115
issue: https://github.com/elastic/elasticsearch/issues/117027
@@ -152,9 +134,6 @@ tests:
152134
- class: org.elasticsearch.xpack.remotecluster.CrossClusterEsqlRCS2UnavailableRemotesIT
153135
method: testEsqlRcs2UnavailableRemoteScenarios
154136
issue: https://github.com/elastic/elasticsearch/issues/117419
155-
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
156-
method: testInferDeploysDefaultRerank
157-
issue: https://github.com/elastic/elasticsearch/issues/118184
158137
- class: org.elasticsearch.xpack.esql.action.EsqlActionTaskIT
159138
method: testCancelRequestWhenFailingFetchingPages
160139
issue: https://github.com/elastic/elasticsearch/issues/118193

x-pack/plugin/inference/qa/inference-service-tests/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
apply plugin: 'elasticsearch.internal-java-rest-test'
22

33
dependencies {
4+
javaRestTestImplementation project(path: xpackModule('core'))
45
javaRestTestImplementation project(path: xpackModule('inference'))
56
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
67
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.xcontent.XContentFactory;
2727
import org.elasticsearch.xcontent.XContentType;
2828
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
29+
import org.junit.Before;
2930
import org.junit.ClassRule;
3031

3132
import java.io.IOException;
@@ -41,6 +42,7 @@
4142
import static org.hamcrest.Matchers.hasSize;
4243

4344
public class InferenceBaseRestTest extends ESRestTestCase {
45+
4446
@ClassRule
4547
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
4648
.distribution(DistributionType.DEFAULT)
@@ -51,6 +53,22 @@ public class InferenceBaseRestTest extends ESRestTestCase {
5153
.feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
5254
.build();
5355

56+
@ClassRule
57+
public static MlModelServer mlModelServer = new MlModelServer();
58+
59+
@Before
60+
public void setMlModelRepository() throws IOException {
61+
logger.info("setting ML model repository to: {}", mlModelServer.getUrl());
62+
var request = new Request("PUT", "/_cluster/settings");
63+
request.setJsonEntity(Strings.format("""
64+
{
65+
"persistent": {
66+
"xpack.ml.model_repository": "%s"
67+
}
68+
}""", mlModelServer.getUrl()));
69+
assertOK(client().performRequest(request));
70+
}
71+
5472
@Override
5573
protected String getTestRestCluster() {
5674
return cluster.getHttpAddresses();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import com.sun.net.httpserver.HttpExchange;
11+
import com.sun.net.httpserver.HttpServer;
12+
13+
import org.apache.http.HttpHeaders;
14+
import org.apache.http.HttpStatus;
15+
import org.apache.http.client.utils.URIBuilder;
16+
import org.elasticsearch.logging.LogManager;
17+
import org.elasticsearch.logging.Logger;
18+
import org.elasticsearch.test.fixture.HttpHeaderParser;
19+
import org.elasticsearch.xcontent.XContentParser;
20+
import org.elasticsearch.xcontent.XContentParserConfiguration;
21+
import org.elasticsearch.xcontent.XContentType;
22+
import org.elasticsearch.xpack.core.XPackSettings;
23+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
24+
import org.junit.rules.TestRule;
25+
import org.junit.runner.Description;
26+
import org.junit.runners.model.Statement;
27+
28+
import java.io.ByteArrayInputStream;
29+
import java.io.IOException;
30+
import java.io.InputStream;
31+
import java.io.OutputStream;
32+
import java.net.InetSocketAddress;
33+
import java.nio.charset.StandardCharsets;
34+
import java.util.Random;
35+
import java.util.concurrent.ExecutorService;
36+
import java.util.concurrent.Executors;
37+
38+
/**
39+
* Simple model server to serve ML models.
40+
* The URL path corresponds to a file name in this class's resources.
41+
* If the file is found, its content is returned, otherwise 404.
42+
* Respects a range header to serve partial content.
43+
*/
44+
public class MlModelServer implements TestRule {
45+
46+
private static final String HOST = "localhost";
47+
private static final Logger logger = LogManager.getLogger(MlModelServer.class);
48+
49+
private int port;
50+
51+
public String getUrl() {
52+
return new URIBuilder().setScheme("http").setHost(HOST).setPort(port).toString();
53+
}
54+
55+
private void handle(HttpExchange exchange) throws IOException {
56+
String rangeHeader = exchange.getRequestHeaders().getFirst(HttpHeaders.RANGE);
57+
HttpHeaderParser.Range range = rangeHeader != null ? HttpHeaderParser.parseRangeHeader(rangeHeader) : null;
58+
logger.info("request: {} range={}", exchange.getRequestURI().getPath(), range);
59+
60+
try (InputStream is = getInputStream(exchange)) {
61+
int httpStatus;
62+
long numBytes;
63+
if (is == null) {
64+
httpStatus = HttpStatus.SC_NOT_FOUND;
65+
numBytes = 0;
66+
} else if (range == null) {
67+
httpStatus = HttpStatus.SC_OK;
68+
numBytes = is.available();
69+
} else {
70+
httpStatus = HttpStatus.SC_PARTIAL_CONTENT;
71+
is.skipNBytes(range.start());
72+
numBytes = range.end() - range.start() + 1;
73+
}
74+
logger.info("response: {} {}", exchange.getRequestURI().getPath(), httpStatus);
75+
exchange.sendResponseHeaders(httpStatus, numBytes);
76+
try (OutputStream os = exchange.getResponseBody()) {
77+
while (numBytes > 0) {
78+
byte[] bytes = is.readNBytes((int) Math.min(1 << 20, numBytes));
79+
os.write(bytes);
80+
numBytes -= bytes.length;
81+
}
82+
}
83+
}
84+
}
85+
86+
private InputStream getInputStream(HttpExchange exchange) throws IOException {
87+
String path = exchange.getRequestURI().getPath().substring(1); // Strip leading slash
88+
String modelId = path.substring(0, path.indexOf('.'));
89+
String extension = path.substring(path.indexOf('.') + 1);
90+
91+
// If a model specifically optimized for some platform is requested,
92+
// serve the default non-optimized model instead, which is compatible.
93+
String defaultModelId = modelId;
94+
for (String platform : XPackSettings.ML_NATIVE_CODE_PLATFORMS) {
95+
defaultModelId = defaultModelId.replace("_" + platform, "");
96+
}
97+
98+
ClassLoader classloader = Thread.currentThread().getContextClassLoader();
99+
InputStream is = classloader.getResourceAsStream(defaultModelId + "." + extension);
100+
if (is != null && modelId.equals(defaultModelId) == false && extension.equals("metadata.json")) {
101+
// When an optimized version is requested, fix the default metadata,
102+
// so that it contains the correct model ID.
103+
try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, is.readAllBytes())) {
104+
is.close();
105+
ModelPackageConfig packageConfig = ModelPackageConfig.fromXContentLenient(parser);
106+
packageConfig = new ModelPackageConfig.Builder(packageConfig).setPackedModelId(modelId).build();
107+
is = new ByteArrayInputStream(packageConfig.toString().getBytes(StandardCharsets.UTF_8));
108+
}
109+
}
110+
return is;
111+
}
112+
113+
@Override
114+
public Statement apply(Statement statement, Description description) {
115+
return new Statement() {
116+
@Override
117+
public void evaluate() throws Throwable {
118+
logger.info("Starting ML model server");
119+
HttpServer server = HttpServer.create();
120+
while (true) {
121+
port = new Random().nextInt(10000, 65536);
122+
try {
123+
server.bind(new InetSocketAddress(HOST, port), 1);
124+
} catch (Exception e) {
125+
continue;
126+
}
127+
break;
128+
}
129+
logger.info("Bound ML model server to port {}", port);
130+
131+
ExecutorService executor = Executors.newCachedThreadPool();
132+
server.setExecutor(executor);
133+
server.createContext("/", MlModelServer.this::handle);
134+
server.start();
135+
136+
try {
137+
statement.evaluate();
138+
} finally {
139+
logger.info("Stopping ML model server on port {}", port);
140+
server.stop(1);
141+
executor.shutdown();
142+
}
143+
}
144+
};
145+
}
146+
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
import static org.hamcrest.Matchers.containsString;
2020

21-
// This test was previously disabled in CI due to the models being too large
22-
// See "https://github.com/elastic/elasticsearch/issues/105198".
2321
public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
2422

2523
public void testPutE5Small_withNoModelVariant() {
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"packaged_model_id": "elser_model_2",
3+
"minimum_version": "11.0.0",
4+
"size": 1859242,
5+
"sha256": "602dbccfb2746e5700bf65d8019b06fb2ec1e3c5bfb980eb2005fc17c1bfe0c0",
6+
"description": "Elastic Learned Sparse EncodeR v2",
7+
"model_type": "pytorch",
8+
"tags": [
9+
"elastic"
10+
],
11+
"inference_config": {
12+
"text_expansion": {
13+
"tokenization": {
14+
"bert": {
15+
"do_lower_case": true,
16+
"with_special_tokens": true,
17+
"max_sequence_length": 512,
18+
"truncate": "first",
19+
"span": -1
20+
}
21+
}
22+
}
23+
},
24+
"vocabulary_file": "elser_model_2.vocab.json"
25+
}
Binary file not shown.

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/elser_model_2.vocab.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
{
2+
"packaged_model_id": "multilingual-e5-small",
3+
"minimum_version": "12.0.0",
4+
"size": 5531160,
5+
"sha256": "92e24566eff554d3a6808cc62731dbecf32db63e01801f3f62210aa9131c7a8b",
6+
"description": "E5 small multilingual",
7+
"model_type": "pytorch",
8+
"tags": [],
9+
"inference_config": {
10+
"text_embedding": {
11+
"tokenization": {
12+
"xlm_roberta": {
13+
"do_lower_case": false,
14+
"with_special_tokens": true,
15+
"max_sequence_length": 512,
16+
"truncate": "first",
17+
"span": -1
18+
}
19+
},
20+
"embedding_size": 384
21+
}
22+
},
23+
"prefix_strings": {
24+
"search": "query: ",
25+
"ingest": "passage: "
26+
},
27+
"metadata": {
28+
"per_allocation_memory_bytes": 557785256,
29+
"per_deployment_memory_bytes": 470031872
30+
},
31+
"vocabulary_file": "multilingual-e5-small.vocab.json"
32+
}

0 commit comments

Comments
 (0)