1010import com .sun .net .httpserver .HttpExchange ;
1111import com .sun .net .httpserver .HttpServer ;
1212
13+ import org .apache .http .HttpHeaders ;
1314import org .apache .http .HttpStatus ;
15+ import org .apache .http .client .utils .URIBuilder ;
1416import org .elasticsearch .logging .LogManager ;
1517import org .elasticsearch .logging .Logger ;
18+ import org .elasticsearch .test .fixture .HttpHeaderParser ;
19+ import org .elasticsearch .xpack .core .XPackSettings ;
1620import org .junit .rules .TestRule ;
1721import org .junit .runner .Description ;
1822import org .junit .runners .model .Statement ;
2731
2832/**
2933 * Simple model server to serve ML models.
30- * The URL path corresponds to file name in this class's resources.
34+ * The URL path corresponds to a file name in this class's resources.
3135 * If the file is found, its content is returned, otherwise 404.
3236 * Respects a range header to serve partial content.
3337 */
3438public class MlModelServer implements TestRule {
3539
40+ private static final String HOST = "localhost" ;
3641 private static final Logger logger = LogManager .getLogger (MlModelServer .class );
3742
3843 private int port ;
3944
40- int getPort () {
41- return port ;
45+ public String getUrl () {
46+ return new URIBuilder (). setScheme ( "http" ). setHost ( HOST ). setPort ( port ). toString () ;
4247 }
4348
44- private static void handle (HttpExchange exchange ) throws IOException {
49+ private static String getFileName (HttpExchange exchange ) {
50+ // Strip the leading slash
4551 String fileName = exchange .getRequestURI ().getPath ().substring (1 );
46- // If this architecture is requested, serve the default model instead.
47- fileName = fileName .replace ("_linux-x86_64" , "" );
48- String range = exchange .getRequestHeaders ().getFirst ("Range" );
49- Integer rangeFrom = null ;
50- Integer rangeTo = null ;
51- if (range != null ) {
52- assert range .startsWith ("bytes=" );
53- assert range .contains ("-" );
54- rangeFrom = Integer .parseInt (range .substring ("bytes=" .length (), range .indexOf ('-' )));
55- rangeTo = Integer .parseInt (range .substring (range .indexOf ('-' ) + 1 )) + 1 ;
52+ // If a model specifically optimized for some platform is requested,
53+ // serve the default non-optimized model instead, which is compatible.
54+ for (String platform : XPackSettings .ML_NATIVE_CODE_PLATFORMS ) {
55+ fileName = fileName .replace ("_" + platform , "" );
5656 }
57- logger .info ("Request: {} range=[{},{})" , fileName , rangeFrom , rangeTo );
57+ return fileName ;
58+ }
59+
60+ private static void handle (HttpExchange exchange ) throws IOException {
61+ String fileName = getFileName (exchange );
62+ String rangeHeader = exchange .getRequestHeaders ().getFirst (HttpHeaders .RANGE );
63+ HttpHeaderParser .Range range = rangeHeader != null ? HttpHeaderParser .parseRangeHeader (rangeHeader ) : null ;
64+ logger .info ("Request: {} {}" , fileName , range == null ? "" : range );
65+
5866 ClassLoader classloader = Thread .currentThread ().getContextClassLoader ();
5967 try (InputStream is = classloader .getResourceAsStream (fileName )) {
68+ int httpStatus ;
69+ long numBytes ;
6070 if (is == null ) {
61- logger . info ( "Response: {} 404" , fileName ) ;
62- exchange . sendResponseHeaders ( HttpStatus . SC_NOT_FOUND , 0 ) ;
71+ httpStatus = HttpStatus . SC_NOT_FOUND ;
72+ numBytes = 0 ;
6373 } else {
64- try (OutputStream os = exchange .getResponseBody ()) {
65- int httpStatus ;
66- int numBytes ;
67- if (range == null ) {
68- httpStatus = HttpStatus .SC_OK ;
69- numBytes = is .available ();
70- } else {
71- httpStatus = HttpStatus .SC_PARTIAL_CONTENT ;
72- is .skipNBytes (rangeFrom );
73- numBytes = rangeTo - rangeFrom ;
74- }
75- logger .info ("Response: {} {}" , fileName , httpStatus );
76- exchange .sendResponseHeaders (httpStatus , numBytes );
77- while (numBytes > 0 ) {
78- byte [] bytes = is .readNBytes (Math .min (1 << 20 , numBytes ));
79- os .write (bytes );
80- numBytes -= bytes .length ;
81- }
74+ if (range == null ) {
75+ httpStatus = HttpStatus .SC_OK ;
76+ numBytes = is .available ();
77+ } else {
78+ httpStatus = HttpStatus .SC_PARTIAL_CONTENT ;
79+ is .skipNBytes (range .start ());
80+ numBytes = range .end () - range .start () + 1 ;
81+ }
82+ }
83+ logger .info ("Response: {} {}" , fileName , httpStatus );
84+ exchange .sendResponseHeaders (httpStatus , numBytes );
85+ try (OutputStream os = exchange .getResponseBody ()) {
86+ while (numBytes > 0 ) {
87+ byte [] bytes = is .readNBytes ((int ) Math .min (1 << 20 , numBytes ));
88+ os .write (bytes );
89+ numBytes -= bytes .length ;
8290 }
8391 }
8492 }
@@ -91,11 +99,10 @@ public Statement apply(Statement statement, Description description) {
9199 public void evaluate () throws Throwable {
92100 logger .info ("Starting ML model server" );
93101 HttpServer server = HttpServer .create ();
94- server .createContext ("/" , MlModelServer ::handle );
95102 while (true ) {
96103 port = new Random ().nextInt (10000 , 65536 );
97104 try {
98- server .bind (new InetSocketAddress ("localhost" , port ), 1 );
105+ server .bind (new InetSocketAddress (HOST , port ), 1 );
99106 } catch (Exception e ) {
100107 continue ;
101108 }
@@ -105,12 +112,13 @@ public void evaluate() throws Throwable {
105112
106113 ExecutorService executor = Executors .newCachedThreadPool ();
107114 server .setExecutor (executor );
115+ server .createContext ("/" , MlModelServer ::handle );
108116 server .start ();
109117
110118 try {
111119 statement .evaluate ();
112120 } finally {
113- logger .info ("Stopping ML model server in port {}" , port );
121+ logger .info ("Stopping ML model server on port {}" , port );
114122 server .stop (1 );
115123 executor .shutdown ();
116124 }
0 commit comments