Skip to content

Commit 3cc62d3

Browse files
committed
polish code
1 parent 08517f2 commit 3cc62d3

File tree

3 files changed

+51
-41
lines changed

3 files changed

+51
-41
lines changed

x-pack/plugin/inference/qa/inference-service-tests/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
apply plugin: 'elasticsearch.internal-java-rest-test'
22

33
dependencies {
4+
javaRestTestImplementation project(path: xpackModule('core'))
45
javaRestTestImplementation project(path: xpackModule('inference'))
56
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
67
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,15 @@ public class InferenceBaseRestTest extends ESRestTestCase {
5858

5959
@Before
6060
public void setMlModelRepository() throws IOException {
61+
logger.info("setting ML model repository");
6162
var request = new Request("PUT", "/_cluster/settings");
6263
request.setJsonEntity(Strings.format("""
6364
{
6465
"persistent": {
65-
"xpack.ml.model_repository": "http://localhost:%d"
66+
"xpack.ml.model_repository": "%s"
6667
}
67-
}""", mlModelServer.getPort()));
68-
client().performRequest(request);
68+
}""", mlModelServer.getUrl()));
69+
assertOK(client().performRequest(request));
6970
}
7071

7172
@Override

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MlModelServer.java

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010
import com.sun.net.httpserver.HttpExchange;
1111
import com.sun.net.httpserver.HttpServer;
1212

13+
import org.apache.http.HttpHeaders;
1314
import org.apache.http.HttpStatus;
15+
import org.apache.http.client.utils.URIBuilder;
1416
import org.elasticsearch.logging.LogManager;
1517
import org.elasticsearch.logging.Logger;
18+
import org.elasticsearch.test.fixture.HttpHeaderParser;
19+
import org.elasticsearch.xpack.core.XPackSettings;
1620
import org.junit.rules.TestRule;
1721
import org.junit.runner.Description;
1822
import org.junit.runners.model.Statement;
@@ -27,58 +31,62 @@
2731

2832
/**
2933
* Simple model server to serve ML models.
30-
* The URL path corresponds to file name in this class's resources.
34+
* The URL path corresponds to a file name in this class's resources.
3135
* If the file is found, its content is returned, otherwise 404.
3236
* Respects a range header to serve partial content.
3337
*/
3438
public class MlModelServer implements TestRule {
3539

40+
private static final String HOST = "localhost";
3641
private static final Logger logger = LogManager.getLogger(MlModelServer.class);
3742

3843
private int port;
3944

40-
int getPort() {
41-
return port;
45+
public String getUrl() {
46+
return new URIBuilder().setScheme("http").setHost(HOST).setPort(port).toString();
4247
}
4348

44-
private static void handle(HttpExchange exchange) throws IOException {
49+
private static String getFileName(HttpExchange exchange) {
50+
// Strip the leading slash
4551
String fileName = exchange.getRequestURI().getPath().substring(1);
46-
// If this architecture is requested, serve the default model instead.
47-
fileName = fileName.replace("_linux-x86_64", "");
48-
String range = exchange.getRequestHeaders().getFirst("Range");
49-
Integer rangeFrom = null;
50-
Integer rangeTo = null;
51-
if (range != null) {
52-
assert range.startsWith("bytes=");
53-
assert range.contains("-");
54-
rangeFrom = Integer.parseInt(range.substring("bytes=".length(), range.indexOf('-')));
55-
rangeTo = Integer.parseInt(range.substring(range.indexOf('-') + 1)) + 1;
52+
// If a model specifically optimized for some platform is requested,
53+
// serve the default non-optimized model instead, which is compatible.
54+
for (String platform : XPackSettings.ML_NATIVE_CODE_PLATFORMS) {
55+
fileName = fileName.replace("_" + platform, "");
5656
}
57-
logger.info("Request: {} range=[{},{})", fileName, rangeFrom, rangeTo);
57+
return fileName;
58+
}
59+
60+
private static void handle(HttpExchange exchange) throws IOException {
61+
String fileName = getFileName(exchange);
62+
String rangeHeader = exchange.getRequestHeaders().getFirst(HttpHeaders.RANGE);
63+
HttpHeaderParser.Range range = rangeHeader != null ? HttpHeaderParser.parseRangeHeader(rangeHeader) : null;
64+
logger.info("Request: {} {}", fileName, range == null ? "" : range);
65+
5866
ClassLoader classloader = Thread.currentThread().getContextClassLoader();
5967
try (InputStream is = classloader.getResourceAsStream(fileName)) {
68+
int httpStatus;
69+
long numBytes;
6070
if (is == null) {
61-
logger.info("Response: {} 404", fileName);
62-
exchange.sendResponseHeaders(HttpStatus.SC_NOT_FOUND, 0);
71+
httpStatus = HttpStatus.SC_NOT_FOUND;
72+
numBytes = 0;
6373
} else {
64-
try (OutputStream os = exchange.getResponseBody()) {
65-
int httpStatus;
66-
int numBytes;
67-
if (range == null) {
68-
httpStatus = HttpStatus.SC_OK;
69-
numBytes = is.available();
70-
} else {
71-
httpStatus = HttpStatus.SC_PARTIAL_CONTENT;
72-
is.skipNBytes(rangeFrom);
73-
numBytes = rangeTo - rangeFrom;
74-
}
75-
logger.info("Response: {} {}", fileName, httpStatus);
76-
exchange.sendResponseHeaders(httpStatus, numBytes);
77-
while (numBytes > 0) {
78-
byte[] bytes = is.readNBytes(Math.min(1 << 20, numBytes));
79-
os.write(bytes);
80-
numBytes -= bytes.length;
81-
}
74+
if (range == null) {
75+
httpStatus = HttpStatus.SC_OK;
76+
numBytes = is.available();
77+
} else {
78+
httpStatus = HttpStatus.SC_PARTIAL_CONTENT;
79+
is.skipNBytes(range.start());
80+
numBytes = range.end() - range.start() + 1;
81+
}
82+
}
83+
logger.info("Response: {} {}", fileName, httpStatus);
84+
exchange.sendResponseHeaders(httpStatus, numBytes);
85+
try (OutputStream os = exchange.getResponseBody()) {
86+
while (numBytes > 0) {
87+
byte[] bytes = is.readNBytes((int) Math.min(1 << 20, numBytes));
88+
os.write(bytes);
89+
numBytes -= bytes.length;
8290
}
8391
}
8492
}
@@ -91,11 +99,10 @@ public Statement apply(Statement statement, Description description) {
9199
public void evaluate() throws Throwable {
92100
logger.info("Starting ML model server");
93101
HttpServer server = HttpServer.create();
94-
server.createContext("/", MlModelServer::handle);
95102
while (true) {
96103
port = new Random().nextInt(10000, 65536);
97104
try {
98-
server.bind(new InetSocketAddress("localhost", port), 1);
105+
server.bind(new InetSocketAddress(HOST, port), 1);
99106
} catch (Exception e) {
100107
continue;
101108
}
@@ -105,12 +112,13 @@ public void evaluate() throws Throwable {
105112

106113
ExecutorService executor = Executors.newCachedThreadPool();
107114
server.setExecutor(executor);
115+
server.createContext("/", MlModelServer::handle);
108116
server.start();
109117

110118
try {
111119
statement.evaluate();
112120
} finally {
113-
logger.info("Stopping ML model server in port {}", port);
121+
logger.info("Stopping ML model server on port {}", port);
114122
server.stop(1);
115123
executor.shutdown();
116124
}

0 commit comments

Comments
 (0)