Skip to content

Commit ca60b55

Browse files
committed
fix metadata for optimized models
1 parent 10cb533 commit ca60b55

File tree

1 file changed

+44
-26
lines changed
  • x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference

1 file changed

+44
-26
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MlModelServer.java

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,21 @@
1616
import org.elasticsearch.logging.LogManager;
1717
import org.elasticsearch.logging.Logger;
1818
import org.elasticsearch.test.fixture.HttpHeaderParser;
19+
import org.elasticsearch.xcontent.XContentParser;
20+
import org.elasticsearch.xcontent.XContentParserConfiguration;
21+
import org.elasticsearch.xcontent.XContentType;
1922
import org.elasticsearch.xpack.core.XPackSettings;
23+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
2024
import org.junit.rules.TestRule;
2125
import org.junit.runner.Description;
2226
import org.junit.runners.model.Statement;
2327

28+
import java.io.ByteArrayInputStream;
2429
import java.io.IOException;
2530
import java.io.InputStream;
2631
import java.io.OutputStream;
2732
import java.net.InetSocketAddress;
33+
import java.nio.charset.StandardCharsets;
2834
import java.util.Random;
2935
import java.util.concurrent.ExecutorService;
3036
import java.util.concurrent.Executors;
@@ -46,41 +52,26 @@ public String getUrl() {
4652
return new URIBuilder().setScheme("http").setHost(HOST).setPort(port).toString();
4753
}
4854

49-
private static String getFileName(HttpExchange exchange) {
50-
// Strip the leading slash
51-
String fileName = exchange.getRequestURI().getPath().substring(1);
52-
// If a model specifically optimized for some platform is requested,
53-
// serve the default non-optimized model instead, which is compatible.
54-
for (String platform : XPackSettings.ML_NATIVE_CODE_PLATFORMS) {
55-
fileName = fileName.replace("_" + platform, "");
56-
}
57-
return fileName;
58-
}
59-
60-
private static void handle(HttpExchange exchange) throws IOException {
61-
String fileName = getFileName(exchange);
55+
private void handle(HttpExchange exchange) throws IOException {
6256
String rangeHeader = exchange.getRequestHeaders().getFirst(HttpHeaders.RANGE);
6357
HttpHeaderParser.Range range = rangeHeader != null ? HttpHeaderParser.parseRangeHeader(rangeHeader) : null;
64-
logger.info("Request: {} {}", fileName, range == null ? "" : range);
58+
logger.info("request: {} range={}", exchange.getRequestURI().getPath(), range);
6559

66-
ClassLoader classloader = Thread.currentThread().getContextClassLoader();
67-
try (InputStream is = classloader.getResourceAsStream(fileName)) {
60+
try (InputStream is = getInputStream(exchange)) {
6861
int httpStatus;
6962
long numBytes;
7063
if (is == null) {
7164
httpStatus = HttpStatus.SC_NOT_FOUND;
7265
numBytes = 0;
66+
} else if (range == null) {
67+
httpStatus = HttpStatus.SC_OK;
68+
numBytes = is.available();
7369
} else {
74-
if (range == null) {
75-
httpStatus = HttpStatus.SC_OK;
76-
numBytes = is.available();
77-
} else {
78-
httpStatus = HttpStatus.SC_PARTIAL_CONTENT;
79-
is.skipNBytes(range.start());
80-
numBytes = range.end() - range.start() + 1;
81-
}
70+
httpStatus = HttpStatus.SC_PARTIAL_CONTENT;
71+
is.skipNBytes(range.start());
72+
numBytes = range.end() - range.start() + 1;
8273
}
83-
logger.info("Response: {} {}", fileName, httpStatus);
74+
logger.info("response: {} {}", exchange.getRequestURI().getPath(), httpStatus);
8475
exchange.sendResponseHeaders(httpStatus, numBytes);
8576
try (OutputStream os = exchange.getResponseBody()) {
8677
while (numBytes > 0) {
@@ -92,6 +83,33 @@ private static void handle(HttpExchange exchange) throws IOException {
9283
}
9384
}
9485

86+
private InputStream getInputStream(HttpExchange exchange) throws IOException {
87+
String path = exchange.getRequestURI().getPath().substring(1); // Strip leading slash
88+
String modelId = path.substring(0, path.indexOf('.'));
89+
String extension = path.substring(path.indexOf('.') + 1);
90+
91+
// If a model specifically optimized for some platform is requested,
92+
// serve the default non-optimized model instead, which is compatible.
93+
String defaultModelId = modelId;
94+
for (String platform : XPackSettings.ML_NATIVE_CODE_PLATFORMS) {
95+
defaultModelId = defaultModelId.replace("_" + platform, "");
96+
}
97+
98+
ClassLoader classloader = Thread.currentThread().getContextClassLoader();
99+
InputStream is = classloader.getResourceAsStream(defaultModelId + "." + extension);
100+
if (is != null && modelId.equals(defaultModelId) == false && extension.equals("metadata.json")) {
101+
// When an optimized version is requested, fix the default metadata,
102+
// so that it contains the correct model ID.
103+
try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, is.readAllBytes())) {
104+
is.close();
105+
ModelPackageConfig packageConfig = ModelPackageConfig.fromXContentLenient(parser);
106+
packageConfig = new ModelPackageConfig.Builder(packageConfig).setPackedModelId(modelId).build();
107+
is = new ByteArrayInputStream(packageConfig.toString().getBytes(StandardCharsets.UTF_8));
108+
}
109+
}
110+
return is;
111+
}
112+
95113
@Override
96114
public Statement apply(Statement statement, Description description) {
97115
return new Statement() {
@@ -112,7 +130,7 @@ public void evaluate() throws Throwable {
112130

113131
ExecutorService executor = Executors.newCachedThreadPool();
114132
server.setExecutor(executor);
115-
server.createContext("/", MlModelServer::handle);
133+
server.createContext("/", MlModelServer.this::handle);
116134
server.start();
117135

118136
try {

0 commit comments

Comments
 (0)