|  | 
|  | 1 | +/* | 
|  | 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | 
|  | 3 | + * or more contributor license agreements. Licensed under the Elastic License | 
|  | 4 | + * 2.0; you may not use this file except in compliance with the Elastic License | 
|  | 5 | + * 2.0. | 
|  | 6 | + */ | 
|  | 7 | + | 
|  | 8 | +package org.elasticsearch.xpack.inference; | 
|  | 9 | + | 
|  | 10 | +import com.sun.net.httpserver.HttpExchange; | 
|  | 11 | +import com.sun.net.httpserver.HttpServer; | 
|  | 12 | + | 
|  | 13 | +import org.apache.http.HttpHeaders; | 
|  | 14 | +import org.apache.http.HttpStatus; | 
|  | 15 | +import org.apache.http.client.utils.URIBuilder; | 
|  | 16 | +import org.elasticsearch.logging.LogManager; | 
|  | 17 | +import org.elasticsearch.logging.Logger; | 
|  | 18 | +import org.elasticsearch.test.fixture.HttpHeaderParser; | 
|  | 19 | +import org.elasticsearch.xcontent.XContentParser; | 
|  | 20 | +import org.elasticsearch.xcontent.XContentParserConfiguration; | 
|  | 21 | +import org.elasticsearch.xcontent.XContentType; | 
|  | 22 | +import org.elasticsearch.xpack.core.XPackSettings; | 
|  | 23 | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; | 
|  | 24 | +import org.junit.rules.TestRule; | 
|  | 25 | +import org.junit.runner.Description; | 
|  | 26 | +import org.junit.runners.model.Statement; | 
|  | 27 | + | 
|  | 28 | +import java.io.ByteArrayInputStream; | 
|  | 29 | +import java.io.IOException; | 
|  | 30 | +import java.io.InputStream; | 
|  | 31 | +import java.io.OutputStream; | 
|  | 32 | +import java.net.InetSocketAddress; | 
|  | 33 | +import java.nio.charset.StandardCharsets; | 
|  | 34 | +import java.util.Random; | 
|  | 35 | +import java.util.concurrent.ExecutorService; | 
|  | 36 | +import java.util.concurrent.Executors; | 
|  | 37 | + | 
|  | 38 | +/** | 
|  | 39 | + * Simple model server to serve ML models. | 
|  | 40 | + * The URL path corresponds to a file name in this class's resources. | 
|  | 41 | + * If the file is found, its content is returned, otherwise 404. | 
|  | 42 | + * Respects a range header to serve partial content. | 
|  | 43 | + */ | 
|  | 44 | +public class MlModelServer implements TestRule { | 
|  | 45 | + | 
|  | 46 | +    private static final String HOST = "localhost"; | 
|  | 47 | +    private static final Logger logger = LogManager.getLogger(MlModelServer.class); | 
|  | 48 | + | 
|  | 49 | +    private int port; | 
|  | 50 | + | 
|  | 51 | +    public String getUrl() { | 
|  | 52 | +        return new URIBuilder().setScheme("http").setHost(HOST).setPort(port).toString(); | 
|  | 53 | +    } | 
|  | 54 | + | 
|  | 55 | +    private void handle(HttpExchange exchange) throws IOException { | 
|  | 56 | +        String rangeHeader = exchange.getRequestHeaders().getFirst(HttpHeaders.RANGE); | 
|  | 57 | +        HttpHeaderParser.Range range = rangeHeader != null ? HttpHeaderParser.parseRangeHeader(rangeHeader) : null; | 
|  | 58 | +        logger.info("request: {} range={}", exchange.getRequestURI().getPath(), range); | 
|  | 59 | + | 
|  | 60 | +        try (InputStream is = getInputStream(exchange)) { | 
|  | 61 | +            int httpStatus; | 
|  | 62 | +            long numBytes; | 
|  | 63 | +            if (is == null) { | 
|  | 64 | +                httpStatus = HttpStatus.SC_NOT_FOUND; | 
|  | 65 | +                numBytes = 0; | 
|  | 66 | +            } else if (range == null) { | 
|  | 67 | +                httpStatus = HttpStatus.SC_OK; | 
|  | 68 | +                numBytes = is.available(); | 
|  | 69 | +            } else { | 
|  | 70 | +                httpStatus = HttpStatus.SC_PARTIAL_CONTENT; | 
|  | 71 | +                is.skipNBytes(range.start()); | 
|  | 72 | +                numBytes = range.end() - range.start() + 1; | 
|  | 73 | +            } | 
|  | 74 | +            logger.info("response: {} {}", exchange.getRequestURI().getPath(), httpStatus); | 
|  | 75 | +            exchange.sendResponseHeaders(httpStatus, numBytes); | 
|  | 76 | +            try (OutputStream os = exchange.getResponseBody()) { | 
|  | 77 | +                while (numBytes > 0) { | 
|  | 78 | +                    byte[] bytes = is.readNBytes((int) Math.min(1 << 20, numBytes)); | 
|  | 79 | +                    os.write(bytes); | 
|  | 80 | +                    numBytes -= bytes.length; | 
|  | 81 | +                } | 
|  | 82 | +            } | 
|  | 83 | +        } | 
|  | 84 | +    } | 
|  | 85 | + | 
|  | 86 | +    private InputStream getInputStream(HttpExchange exchange) throws IOException { | 
|  | 87 | +        String path = exchange.getRequestURI().getPath().substring(1);  // Strip leading slash | 
|  | 88 | +        String modelId = path.substring(0, path.indexOf('.')); | 
|  | 89 | +        String extension = path.substring(path.indexOf('.') + 1); | 
|  | 90 | + | 
|  | 91 | +        // If a model specifically optimized for some platform is requested, | 
|  | 92 | +        // serve the default non-optimized model instead, which is compatible. | 
|  | 93 | +        String defaultModelId = modelId; | 
|  | 94 | +        for (String platform : XPackSettings.ML_NATIVE_CODE_PLATFORMS) { | 
|  | 95 | +            defaultModelId = defaultModelId.replace("_" + platform, ""); | 
|  | 96 | +        } | 
|  | 97 | + | 
|  | 98 | +        ClassLoader classloader = Thread.currentThread().getContextClassLoader(); | 
|  | 99 | +        InputStream is = classloader.getResourceAsStream(defaultModelId + "." + extension); | 
|  | 100 | +        if (is != null && modelId.equals(defaultModelId) == false && extension.equals("metadata.json")) { | 
|  | 101 | +            // When an optimized version is requested, fix the default metadata, | 
|  | 102 | +            // so that it contains the correct model ID. | 
|  | 103 | +            try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, is.readAllBytes())) { | 
|  | 104 | +                is.close(); | 
|  | 105 | +                ModelPackageConfig packageConfig = ModelPackageConfig.fromXContentLenient(parser); | 
|  | 106 | +                packageConfig = new ModelPackageConfig.Builder(packageConfig).setPackedModelId(modelId).build(); | 
|  | 107 | +                is = new ByteArrayInputStream(packageConfig.toString().getBytes(StandardCharsets.UTF_8)); | 
|  | 108 | +            } | 
|  | 109 | +        } | 
|  | 110 | +        return is; | 
|  | 111 | +    } | 
|  | 112 | + | 
|  | 113 | +    @Override | 
|  | 114 | +    public Statement apply(Statement statement, Description description) { | 
|  | 115 | +        return new Statement() { | 
|  | 116 | +            @Override | 
|  | 117 | +            public void evaluate() throws Throwable { | 
|  | 118 | +                logger.info("Starting ML model server"); | 
|  | 119 | +                HttpServer server = HttpServer.create(); | 
|  | 120 | +                while (true) { | 
|  | 121 | +                    port = new Random().nextInt(10000, 65536); | 
|  | 122 | +                    try { | 
|  | 123 | +                        server.bind(new InetSocketAddress(HOST, port), 1); | 
|  | 124 | +                    } catch (Exception e) { | 
|  | 125 | +                        continue; | 
|  | 126 | +                    } | 
|  | 127 | +                    break; | 
|  | 128 | +                } | 
|  | 129 | +                logger.info("Bound ML model server to port {}", port); | 
|  | 130 | + | 
|  | 131 | +                ExecutorService executor = Executors.newCachedThreadPool(); | 
|  | 132 | +                server.setExecutor(executor); | 
|  | 133 | +                server.createContext("/", MlModelServer.this::handle); | 
|  | 134 | +                server.start(); | 
|  | 135 | + | 
|  | 136 | +                try { | 
|  | 137 | +                    statement.evaluate(); | 
|  | 138 | +                } finally { | 
|  | 139 | +                    logger.info("Stopping ML model server on port {}", port); | 
|  | 140 | +                    server.stop(1); | 
|  | 141 | +                    executor.shutdown(); | 
|  | 142 | +                } | 
|  | 143 | +            } | 
|  | 144 | +        }; | 
|  | 145 | +    } | 
|  | 146 | +} | 
0 commit comments