Skip to content

Commit af8f5d2

Browse files
committed
@ClassRule
1 parent 4742325 commit af8f5d2

File tree

4 files changed

+59
-74
lines changed

4 files changed

+59
-74
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
1717
import org.hamcrest.Matchers;
1818
import org.junit.After;
19-
import org.junit.AfterClass;
2019
import org.junit.Before;
21-
import org.junit.BeforeClass;
2220

2321
import java.io.IOException;
2422
import java.util.ArrayList;
@@ -33,16 +31,6 @@
3331

3432
public class DefaultEndPointsIT extends InferenceBaseRestTest {
3533

36-
@BeforeClass
37-
public static void startModelServer() {
38-
mlModelServer.start();
39-
}
40-
41-
@AfterClass
42-
public static void stopModelServer() {
43-
mlModelServer.stop();
44-
}
45-
4634
private TestThreadPool threadPool;
4735

4836
@Before

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.xcontent.XContentFactory;
2727
import org.elasticsearch.xcontent.XContentType;
2828
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
29+
import org.junit.Before;
2930
import org.junit.ClassRule;
3031

3132
import java.io.IOException;
@@ -42,20 +43,31 @@
4243

4344
public class InferenceBaseRestTest extends ESRestTestCase {
4445

45-
@ClassRule(order = 0)
46-
public static MlModelServer mlModelServer = new MlModelServer();
47-
48-
@ClassRule(order = 1)
46+
@ClassRule
4947
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
5048
.distribution(DistributionType.DEFAULT)
5149
.setting("xpack.license.self_generated.type", "trial")
5250
.setting("xpack.security.enabled", "true")
53-
.setting("xpack.ml.model_repository", "http://localhost:" + mlModelServer.getPort())
5451
.plugin("inference-service-test")
5552
.user("x_pack_rest_user", "x-pack-test-password")
5653
.feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
5754
.build();
5855

56+
@ClassRule
57+
public static MlModelServer mlModelServer = new MlModelServer();
58+
59+
@Before
60+
public void setMlModelRepository() throws IOException {
61+
var request = new Request("PUT", "/_cluster/settings");
62+
request.setJsonEntity(Strings.format("""
63+
{
64+
"persistent": {
65+
"xpack.ml.model_repository": "http://localhost:%d"
66+
}
67+
}""", mlModelServer.getPort()));
68+
client().performRequest(request);
69+
}
70+
5971
@Override
6072
protected String getTestRestCluster() {
6173
return cluster.getHttpAddresses();

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MlModelServer.java

Lines changed: 42 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,15 @@
1313
import org.apache.http.HttpStatus;
1414
import org.elasticsearch.logging.LogManager;
1515
import org.elasticsearch.logging.Logger;
16+
import org.junit.rules.TestRule;
17+
import org.junit.runner.Description;
18+
import org.junit.runners.model.Statement;
1619

1720
import java.io.IOException;
1821
import java.io.InputStream;
1922
import java.io.OutputStream;
2023
import java.net.InetSocketAddress;
24+
import java.util.Random;
2125
import java.util.concurrent.ExecutorService;
2226
import java.util.concurrent.Executors;
2327

@@ -27,56 +31,17 @@
2731
* If the file is found, its content is returned, otherwise 404.
2832
* Respects a range header to serve partial content.
2933
*/
30-
public class MlModelServer {
34+
public class MlModelServer implements TestRule {
3135

3236
private static final Logger logger = LogManager.getLogger(MlModelServer.class);
3337

34-
private final int port;
35-
private final HttpServer server;
38+
private int port;
3639

37-
private ExecutorService executor;
38-
39-
public MlModelServer() {
40-
try {
41-
server = HttpServer.create();
42-
} catch (IOException e) {
43-
throw new RuntimeException("Could not create server", e);
44-
}
45-
server.createContext("/", this::handle);
46-
port = findUnusedPort();
47-
}
48-
49-
private int findUnusedPort() {
50-
Exception exception = null;
51-
for (int port = 10000; port < 11000; port++) {
52-
try {
53-
server.bind(new InetSocketAddress(port), 0);
54-
return port;
55-
} catch (IOException e) {
56-
exception = e;
57-
}
58-
}
59-
throw new RuntimeException("Could not find port", exception);
60-
}
61-
62-
public int getPort() {
40+
int getPort() {
6341
return port;
6442
}
6543

66-
public void start() {
67-
logger.info("Starting ML model server on port {}", port);
68-
executor = Executors.newCachedThreadPool();
69-
server.setExecutor(executor);
70-
server.start();
71-
}
72-
73-
public void stop() {
74-
logger.info("Stopping ML model server in port {}", port);
75-
server.stop(1);
76-
executor.shutdown();
77-
}
78-
79-
private void handle(HttpExchange exchange) throws IOException {
44+
private static void handle(HttpExchange exchange) throws IOException {
8045
String fileName = exchange.getRequestURI().getPath().substring(1);
8146
// If this architecture is requested, serve the default model instead.
8247
fileName = fileName.replace("_linux-x86_64", "");
@@ -118,4 +83,38 @@ private void handle(HttpExchange exchange) throws IOException {
11883
}
11984
}
12085
}
86+
87+
@Override
88+
public Statement apply(Statement statement, Description description) {
89+
return new Statement() {
90+
@Override
91+
public void evaluate() throws Throwable {
92+
logger.info("Starting ML model server");
93+
HttpServer server = HttpServer.create();
94+
server.createContext("/", MlModelServer::handle);
95+
while (true) {
96+
port = new Random().nextInt(10000, 65536);
97+
try {
98+
server.bind(new InetSocketAddress("localhost", port), 1);
99+
} catch (Exception e) {
100+
continue;
101+
}
102+
break;
103+
}
104+
logger.info("Bound ML model server to port {}", port);
105+
106+
ExecutorService executor = Executors.newCachedThreadPool();
107+
server.setExecutor(executor);
108+
server.start();
109+
110+
try {
111+
statement.evaluate();
112+
} finally {
113+
logger.info("Stopping ML model server in port {}", port);
114+
server.stop(1);
115+
executor.shutdown();
116+
}
117+
}
118+
};
119+
}
121120
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,15 @@
1111
import org.elasticsearch.common.Strings;
1212
import org.elasticsearch.inference.TaskType;
1313
import org.elasticsearch.plugins.Platforms;
14-
import org.junit.AfterClass;
15-
import org.junit.BeforeClass;
1614

1715
import java.io.IOException;
1816
import java.util.List;
1917
import java.util.Map;
2018

2119
import static org.hamcrest.Matchers.containsString;
2220

23-
// This test was previously disabled in CI due to the models being too large
24-
// See "https://github.com/elastic/elasticsearch/issues/105198".
2521
public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
2622

27-
@BeforeClass
28-
public static void startModelServer() {
29-
mlModelServer.start();
30-
}
31-
32-
@AfterClass
33-
public static void stopModelServer() {
34-
mlModelServer.stop();
35-
}
36-
3723
public void testPutE5Small_withNoModelVariant() {
3824
{
3925
String inferenceEntityId = "testPutE5Small_withNoModelVariant";

0 commit comments

Comments
 (0)