1616import org .elasticsearch .logging .LogManager ;
1717import org .elasticsearch .logging .Logger ;
1818import org .elasticsearch .test .fixture .HttpHeaderParser ;
19+ import org .elasticsearch .xcontent .XContentParser ;
20+ import org .elasticsearch .xcontent .XContentParserConfiguration ;
21+ import org .elasticsearch .xcontent .XContentType ;
1922import org .elasticsearch .xpack .core .XPackSettings ;
23+ import org .elasticsearch .xpack .core .ml .inference .trainedmodel .ModelPackageConfig ;
2024import org .junit .rules .TestRule ;
2125import org .junit .runner .Description ;
2226import org .junit .runners .model .Statement ;
2327
28+ import java .io .ByteArrayInputStream ;
2429import java .io .IOException ;
2530import java .io .InputStream ;
2631import java .io .OutputStream ;
2732import java .net .InetSocketAddress ;
33+ import java .nio .charset .StandardCharsets ;
2834import java .util .Random ;
2935import java .util .concurrent .ExecutorService ;
3036import java .util .concurrent .Executors ;
@@ -46,41 +52,26 @@ public String getUrl() {
4652 return new URIBuilder ().setScheme ("http" ).setHost (HOST ).setPort (port ).toString ();
4753 }
4854
49- private static String getFileName (HttpExchange exchange ) {
50- // Strip the leading slash
51- String fileName = exchange .getRequestURI ().getPath ().substring (1 );
52- // If a model specifically optimized for some platform is requested,
53- // serve the default non-optimized model instead, which is compatible.
54- for (String platform : XPackSettings .ML_NATIVE_CODE_PLATFORMS ) {
55- fileName = fileName .replace ("_" + platform , "" );
56- }
57- return fileName ;
58- }
59-
60- private static void handle (HttpExchange exchange ) throws IOException {
61- String fileName = getFileName (exchange );
55+ private void handle (HttpExchange exchange ) throws IOException {
6256 String rangeHeader = exchange .getRequestHeaders ().getFirst (HttpHeaders .RANGE );
6357 HttpHeaderParser .Range range = rangeHeader != null ? HttpHeaderParser .parseRangeHeader (rangeHeader ) : null ;
64- logger .info ("Request : {} {}" , fileName , range == null ? "" : range );
58+ logger .info ("request : {} range= {}" , exchange . getRequestURI (). getPath (), range );
6559
66- ClassLoader classloader = Thread .currentThread ().getContextClassLoader ();
67- try (InputStream is = classloader .getResourceAsStream (fileName )) {
60+ try (InputStream is = getInputStream (exchange )) {
6861 int httpStatus ;
6962 long numBytes ;
7063 if (is == null ) {
7164 httpStatus = HttpStatus .SC_NOT_FOUND ;
7265 numBytes = 0 ;
66+ } else if (range == null ) {
67+ httpStatus = HttpStatus .SC_OK ;
68+ numBytes = is .available ();
7369 } else {
74- if (range == null ) {
75- httpStatus = HttpStatus .SC_OK ;
76- numBytes = is .available ();
77- } else {
78- httpStatus = HttpStatus .SC_PARTIAL_CONTENT ;
79- is .skipNBytes (range .start ());
80- numBytes = range .end () - range .start () + 1 ;
81- }
70+ httpStatus = HttpStatus .SC_PARTIAL_CONTENT ;
71+ is .skipNBytes (range .start ());
72+ numBytes = range .end () - range .start () + 1 ;
8273 }
83- logger .info ("Response : {} {}" , fileName , httpStatus );
74+ logger .info ("response : {} {}" , exchange . getRequestURI (). getPath () , httpStatus );
8475 exchange .sendResponseHeaders (httpStatus , numBytes );
8576 try (OutputStream os = exchange .getResponseBody ()) {
8677 while (numBytes > 0 ) {
@@ -92,6 +83,33 @@ private static void handle(HttpExchange exchange) throws IOException {
9283 }
9384 }
9485
86+ private InputStream getInputStream (HttpExchange exchange ) throws IOException {
87+ String path = exchange .getRequestURI ().getPath ().substring (1 ); // Strip leading slash
88+ String modelId = path .substring (0 , path .indexOf ('.' ));
89+ String extension = path .substring (path .indexOf ('.' ) + 1 );
90+
91+ // If a model specifically optimized for some platform is requested,
92+ // serve the default non-optimized model instead, which is compatible.
93+ String defaultModelId = modelId ;
94+ for (String platform : XPackSettings .ML_NATIVE_CODE_PLATFORMS ) {
95+ defaultModelId = defaultModelId .replace ("_" + platform , "" );
96+ }
97+
98+ ClassLoader classloader = Thread .currentThread ().getContextClassLoader ();
99+ InputStream is = classloader .getResourceAsStream (defaultModelId + "." + extension );
100+ if (is != null && modelId .equals (defaultModelId ) == false && extension .equals ("metadata.json" )) {
101+ // When an optimized version is requested, fix the default metadata,
102+ // so that it contains the correct model ID.
103+ try (XContentParser parser = XContentType .JSON .xContent ().createParser (XContentParserConfiguration .EMPTY , is .readAllBytes ())) {
104+ is .close ();
105+ ModelPackageConfig packageConfig = ModelPackageConfig .fromXContentLenient (parser );
106+ packageConfig = new ModelPackageConfig .Builder (packageConfig ).setPackedModelId (modelId ).build ();
107+ is = new ByteArrayInputStream (packageConfig .toString ().getBytes (StandardCharsets .UTF_8 ));
108+ }
109+ }
110+ return is ;
111+ }
112+
95113 @ Override
96114 public Statement apply (Statement statement , Description description ) {
97115 return new Statement () {
@@ -112,7 +130,7 @@ public void evaluate() throws Throwable {
112130
113131 ExecutorService executor = Executors .newCachedThreadPool ();
114132 server .setExecutor (executor );
115- server .createContext ("/" , MlModelServer ::handle );
133+ server .createContext ("/" , MlModelServer . this ::handle );
116134 server .start ();
117135
118136 try {
0 commit comments