diff --git a/docs/changelog/111684.yaml b/docs/changelog/111684.yaml new file mode 100644 index 0000000000000..32edb5723cb0a --- /dev/null +++ b/docs/changelog/111684.yaml @@ -0,0 +1,5 @@ +pr: 111684 +summary: Write downloaded model parts async +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java index 7fb47e901f703..6c15b42dc65d5 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference; -import org.apache.lucene.tests.util.LuceneTestCase; import org.elasticsearch.client.Request; import org.elasticsearch.common.Strings; import org.elasticsearch.inference.TaskType; @@ -19,11 +18,11 @@ import static org.hamcrest.Matchers.containsString; -// Tests disabled in CI due to the models being too large to download. Can be enabled (commented out) for local testing -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105198") +// This test was previously disabled in CI due to the models being too large +// See "https://github.com/elastic/elasticsearch/issues/105198". 
public class TextEmbeddingCrudIT extends InferenceBaseRestTest { - public void testPutE5Small_withNoModelVariant() throws IOException { + public void testPutE5Small_withNoModelVariant() { { String inferenceEntityId = randomAlphaOfLength(10).toLowerCase(); expectThrows( @@ -51,6 +50,7 @@ public void testPutE5Small_withPlatformAgnosticVariant() throws IOException { deleteTextEmbeddingModel(inferenceEntityId); } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105198") public void testPutE5Small_withPlatformSpecificVariant() throws IOException { String inferenceEntityId = randomAlphaOfLength(10).toLowerCase(); if ("linux-x86_64".equals(Platforms.PLATFORM_NAME)) { @@ -124,7 +124,7 @@ private String noModelIdVariantJsonEntity() { private String platformAgnosticModelVariantJsonEntity() { return """ { - "service": "text_embedding", + "service": "elasticsearch", "service_settings": { "num_allocations": 1, "num_threads": 1, @@ -137,7 +137,7 @@ private String platformAgnosticModelVariantJsonEntity() { private String platformSpecificModelVariantJsonEntity() { return """ { - "service": "text_embedding", + "service": "elasticsearch", "service_settings": { "num_allocations": 1, "num_threads": 1, @@ -150,7 +150,7 @@ private String platformSpecificModelVariantJsonEntity() { private String fakeModelVariantJsonEntity() { return """ { - "service": "text_embedding", + "service": "elasticsearch", "service_settings": { "num_allocations": 1, "num_threads": 1, diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java index e927c46e6bd29..a63d911e9d40d 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java +++ 
b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java @@ -15,12 +15,17 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ExecutorBuilder; +import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction; import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask; +import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter; import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction; import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage; @@ -44,9 +49,6 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin Setting.Property.Dynamic ); - // re-using thread pool setup by the ml plugin - public static final String UTILITY_THREAD_POOL_NAME = "ml_utility"; - // This link will be invalid for serverless, but serverless will never be // air-gapped, so this message should never be needed. 
private static final String MODEL_REPOSITORY_DOCUMENTATION_LINK = format( @@ -54,6 +56,8 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin Build.current().version().replaceFirst("^(\\d+\\.\\d+).*", "$1") ); + public static final String MODEL_DOWNLOAD_THREADPOOL_NAME = "model_download"; + public MachineLearningPackageLoader() {} @Override @@ -81,6 +85,24 @@ public List getNamedWriteables() { ); } + @Override + public List> getExecutorBuilders(Settings settings) { + return List.of(modelDownloadExecutor(settings)); + } + + public static FixedExecutorBuilder modelDownloadExecutor(Settings settings) { + // Threadpool with a fixed number of threads for + // downloading the model definition files + return new FixedExecutorBuilder( + settings, + MODEL_DOWNLOAD_THREADPOOL_NAME, + ModelImporter.NUMBER_OF_STREAMS, + -1, // unbounded queue size + "xpack.ml.model_download_thread_pool", + EsExecutors.TaskTrackingConfig.DO_NOT_TRACK + ); + } + @Override public List getBootstrapChecks() { return List.of(new BootstrapCheck() { diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java index 33d5d5982d2b0..b155d6c73ccef 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java @@ -10,124 +10,265 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionRequest; -import org.elasticsearch.action.ActionResponse; -import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.RefCountingListener; 
+import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.Strings; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; +import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; -import java.io.IOException; import java.io.InputStream; import java.net.URI; import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.core.Strings.format; /** - * A helper class for abstracting out the use of the ModelLoaderUtils to make dependency injection testing easier. + * For downloading the vocabulary and model definition files and + * indexing those files in Elasticsearch. + * Holding the large model definition file in memory will consume + * too much memory, instead it is streamed in chunks and each chunk + * written to the index in a non-blocking request. + * The model files may be installed from a local file or downloaded + * from a server. The server download uses {@link #NUMBER_OF_STREAMS} + * connections each using the Range header to split the stream by byte + * range. There is a complication in that the final part of the model + * definition must be uploaded last as writing this part causes an index + * refresh. 
+ * When read from file a single thread is used to read the file + * stream, split into chunks and index those chunks. */ -class ModelImporter { +public class ModelImporter { private static final int DEFAULT_CHUNK_SIZE = 1024 * 1024; // 1MB + public static final int NUMBER_OF_STREAMS = 5; private static final Logger logger = LogManager.getLogger(ModelImporter.class); private final Client client; private final String modelId; private final ModelPackageConfig config; private final ModelDownloadTask task; + private final ExecutorService executorService; + private final AtomicInteger progressCounter = new AtomicInteger(); + private final URI uri; + private final CircuitBreakerService breakerService; - ModelImporter(Client client, String modelId, ModelPackageConfig packageConfig, ModelDownloadTask task) { + ModelImporter( + Client client, + String modelId, + ModelPackageConfig packageConfig, + ModelDownloadTask task, + ThreadPool threadPool, + CircuitBreakerService cbs + ) throws URISyntaxException { this.client = client; this.modelId = Objects.requireNonNull(modelId); this.config = Objects.requireNonNull(packageConfig); this.task = Objects.requireNonNull(task); + this.executorService = threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME); + this.uri = ModelLoaderUtils.resolvePackageLocation( + config.getModelRepository(), + config.getPackagedModelId() + ModelLoaderUtils.MODEL_FILE_EXTENSION + ); + this.breakerService = cbs; } - public void doImport() throws URISyntaxException, IOException, ElasticsearchStatusException { - long size = config.getSize(); + public void doImport(ActionListener listener) { + executorService.execute(() -> doImportInternal(listener)); + } - // Uploading other artefacts of the model first, that way the model is last and a simple search can be used to check if the - // download is complete - if (Strings.isNullOrEmpty(config.getVocabularyFile()) == false) { - uploadVocabulary(); + private void 
doImportInternal(ActionListener finalListener) { + assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) + : format( + "Model download must execute from [%s] but thread is [%s]", + MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, + Thread.currentThread().getName() + ); - logger.debug(() -> format("[%s] imported model vocabulary [%s]", modelId, config.getVocabularyFile())); - } + ModelLoaderUtils.VocabularyParts vocabularyParts = null; + try { + if (config.getVocabularyFile() != null) { + vocabularyParts = ModelLoaderUtils.loadVocabulary( + ModelLoaderUtils.resolvePackageLocation(config.getModelRepository(), config.getVocabularyFile()) + ); + } - URI uri = ModelLoaderUtils.resolvePackageLocation( - config.getModelRepository(), - config.getPackagedModelId() + ModelLoaderUtils.MODEL_FILE_EXTENSION - ); + // simple round up + int totalParts = (int) ((config.getSize() + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE); - InputStream modelInputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri); + if (ModelLoaderUtils.uriIsFile(uri) == false) { + breakerService.getBreaker(CircuitBreaker.REQUEST) + .addEstimateBytesAndMaybeBreak(DEFAULT_CHUNK_SIZE * NUMBER_OF_STREAMS, "model importer"); + var breakerFreeingListener = ActionListener.runAfter( + finalListener, + () -> breakerService.getBreaker(CircuitBreaker.REQUEST).addWithoutBreaking(-(DEFAULT_CHUNK_SIZE * NUMBER_OF_STREAMS)) + ); - ModelLoaderUtils.InputStreamChunker chunkIterator = new ModelLoaderUtils.InputStreamChunker(modelInputStream, DEFAULT_CHUNK_SIZE); + var ranges = ModelLoaderUtils.split(config.getSize(), NUMBER_OF_STREAMS, DEFAULT_CHUNK_SIZE); + var downloaders = new ArrayList(ranges.size()); + for (var range : ranges) { + downloaders.add(new ModelLoaderUtils.HttpStreamChunker(uri, range, DEFAULT_CHUNK_SIZE)); + } + downloadModelDefinition(config.getSize(), totalParts, vocabularyParts, downloaders, breakerFreeingListener); + } else { + 
InputStream modelInputStream = ModelLoaderUtils.getFileInputStream(uri); + ModelLoaderUtils.InputStreamChunker chunkIterator = new ModelLoaderUtils.InputStreamChunker( + modelInputStream, + DEFAULT_CHUNK_SIZE + ); + readModelDefinitionFromFile(config.getSize(), totalParts, chunkIterator, vocabularyParts, finalListener); + } + } catch (Exception e) { + finalListener.onFailure(e); + } + } - // simple round up - int totalParts = (int) ((size + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE); + void downloadModelDefinition( + long size, + int totalParts, + @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts, + List downloaders, + ActionListener finalListener + ) { + try (var countingListener = new RefCountingListener(1, ActionListener.wrap(ignore -> executorService.execute(() -> { + var finalDownloader = downloaders.get(downloaders.size() - 1); + downloadFinalPart(size, totalParts, finalDownloader, finalListener.delegateFailureAndWrap((l, r) -> { + checkDownloadComplete(downloaders); + l.onResponse(AcknowledgedResponse.TRUE); + })); + }), finalListener::onFailure))) { + // Uploading other artefacts of the model first, that way the model is last and a simple search can be used to check if the + // download is complete + if (vocabularyParts != null) { + uploadVocabulary(vocabularyParts, countingListener); + } - for (int part = 0; part < totalParts - 1; ++part) { - task.setProgress(totalParts, part); - BytesArray definition = chunkIterator.next(); + // Download all but the final split. 
+ // The final split is a single chunk + for (int streamSplit = 0; streamSplit < downloaders.size() - 1; ++streamSplit) { + final var downloader = downloaders.get(streamSplit); + var rangeDownloadedListener = countingListener.acquire(); // acquire to keep the counting listener from closing + executorService.execute( + () -> downloadPartInRange(size, totalParts, downloader, executorService, countingListener, rangeDownloadedListener) + ); + } + } + } - PutTrainedModelDefinitionPartAction.Request modelPartRequest = new PutTrainedModelDefinitionPartAction.Request( - modelId, - definition, - part, - size, - totalParts, - true + private void downloadPartInRange( + long size, + int totalParts, + ModelLoaderUtils.HttpStreamChunker downloadChunker, + ExecutorService executorService, + RefCountingListener countingListener, + ActionListener rangeFullyDownloadedListener + ) { + assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) + : format( + "Model download must execute from [%s] but thread is [%s]", + MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, + Thread.currentThread().getName() ); - executeRequestIfNotCancelled(PutTrainedModelDefinitionPartAction.INSTANCE, modelPartRequest); + if (countingListener.isFailing()) { + rangeFullyDownloadedListener.onResponse(null); // the error has already been reported elsewhere + return; } - // get the last part, this time verify the checksum and size - BytesArray definition = chunkIterator.next(); + try { + throwIfTaskCancelled(); + var bytesAndIndex = downloadChunker.next(); + task.setProgress(totalParts, progressCounter.getAndIncrement()); - if (config.getSha256().equals(chunkIterator.getSha256()) == false) { - String message = format( - "Model sha256 checksums do not match, expected [%s] but got [%s]", - config.getSha256(), - chunkIterator.getSha256() - ); + indexPart(bytesAndIndex.partIndex(), totalParts, size, bytesAndIndex.bytes()); + } catch (Exception e) { + 
rangeFullyDownloadedListener.onFailure(e); + return; + } - throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); + if (downloadChunker.hasNext()) { + executorService.execute( + () -> downloadPartInRange( + size, + totalParts, + downloadChunker, + executorService, + countingListener, + rangeFullyDownloadedListener + ) + ); + } else { + rangeFullyDownloadedListener.onResponse(null); } + } - if (config.getSize() != chunkIterator.getTotalBytesRead()) { - String message = format( - "Model size does not match, expected [%d] but got [%d]", - config.getSize(), - chunkIterator.getTotalBytesRead() + private void downloadFinalPart( + long size, + int totalParts, + ModelLoaderUtils.HttpStreamChunker downloader, + ActionListener lastPartWrittenListener + ) { + assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) + : format( + "Model download must execute from [%s] but thread is [%s]", + MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, + Thread.currentThread().getName() ); - throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); + try { + var bytesAndIndex = downloader.next(); + task.setProgress(totalParts, progressCounter.getAndIncrement()); + + indexPart(bytesAndIndex.partIndex(), totalParts, size, bytesAndIndex.bytes()); + lastPartWrittenListener.onResponse(AcknowledgedResponse.TRUE); + } catch (Exception e) { + lastPartWrittenListener.onFailure(e); } + } - PutTrainedModelDefinitionPartAction.Request finalModelPartRequest = new PutTrainedModelDefinitionPartAction.Request( - modelId, - definition, - totalParts - 1, - size, - totalParts, - true - ); + void readModelDefinitionFromFile( + long size, + int totalParts, + ModelLoaderUtils.InputStreamChunker chunkIterator, + @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts, + ActionListener finalListener + ) { + try (var countingListener = new RefCountingListener(1, ActionListener.wrap(ignore -> 
executorService.execute(() -> { + finalListener.onResponse(AcknowledgedResponse.TRUE); + }), finalListener::onFailure))) { + try { + if (vocabularyParts != null) { + uploadVocabulary(vocabularyParts, countingListener); + } - executeRequestIfNotCancelled(PutTrainedModelDefinitionPartAction.INSTANCE, finalModelPartRequest); - logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); - } + for (int part = 0; part < totalParts; ++part) { + throwIfTaskCancelled(); + task.setProgress(totalParts, part); + BytesArray definition = chunkIterator.next(); + indexPart(part, totalParts, size, definition); + } + task.setProgress(totalParts, totalParts); - private void uploadVocabulary() throws URISyntaxException { - ModelLoaderUtils.VocabularyParts vocabularyParts = ModelLoaderUtils.loadVocabulary( - ModelLoaderUtils.resolvePackageLocation(config.getModelRepository(), config.getVocabularyFile()) - ); + checkDownloadComplete(chunkIterator, totalParts); + } catch (Exception e) { + countingListener.acquire().onFailure(e); + } + } + } + private void uploadVocabulary(ModelLoaderUtils.VocabularyParts vocabularyParts, RefCountingListener countingListener) { PutTrainedModelVocabularyAction.Request request = new PutTrainedModelVocabularyAction.Request( modelId, vocabularyParts.vocab(), @@ -136,17 +277,58 @@ private void uploadVocabulary() throws URISyntaxException { true ); - executeRequestIfNotCancelled(PutTrainedModelVocabularyAction.INSTANCE, request); + client.execute(PutTrainedModelVocabularyAction.INSTANCE, request, countingListener.acquire(r -> { + logger.debug(() -> format("[%s] imported model vocabulary [%s]", modelId, config.getVocabularyFile())); + })); } - private void executeRequestIfNotCancelled( - ActionType action, - Request request - ) { - if (task.isCancelled()) { - throw new TaskCancelledException(format("task cancelled with reason [%s]", task.getReasonCancelled())); + private void indexPart(int partIndex, int totalParts, long 
totalSize, BytesArray bytes) { + PutTrainedModelDefinitionPartAction.Request modelPartRequest = new PutTrainedModelDefinitionPartAction.Request( + modelId, + bytes, + partIndex, + totalSize, + totalParts, + true + ); + + client.execute(PutTrainedModelDefinitionPartAction.INSTANCE, modelPartRequest).actionGet(); + } + + private void checkDownloadComplete(List downloaders) { + long totalBytesRead = downloaders.stream().mapToLong(ModelLoaderUtils.HttpStreamChunker::getTotalBytesRead).sum(); + int totalParts = downloaders.stream().mapToInt(ModelLoaderUtils.HttpStreamChunker::getCurrentPart).sum(); + checkSize(totalBytesRead); + logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); + } + + private void checkDownloadComplete(ModelLoaderUtils.InputStreamChunker fileInputStream, int totalParts) { + checkSha256(fileInputStream.getSha256()); + checkSize(fileInputStream.getTotalBytesRead()); + logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); + } + + private void checkSha256(String sha256) { + if (config.getSha256().equals(sha256) == false) { + String message = format("Model sha256 checksums do not match, expected [%s] but got [%s]", config.getSha256(), sha256); + + throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); } + } - client.execute(action, request).actionGet(); + private void checkSize(long definitionSize) { + if (config.getSize() != definitionSize) { + String message = format("Model size does not match, expected [%d] but got [%d]", config.getSize(), definitionSize); + throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); + } + } + + private void throwIfTaskCancelled() { + if (task.isCancelled()) { + logger.info("Model [{}] download task cancelled", modelId); + throw new TaskCancelledException( + format("Model [%s] download task cancelled with reason [%s]", modelId, task.getReasonCancelled()) + ); + } } } diff --git 
a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java index 2f3f9cbf3f32c..e92aff74be463 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentParser; @@ -34,16 +35,20 @@ import java.security.AccessController; import java.security.MessageDigest; import java.security.PrivilegedAction; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; import static java.net.HttpURLConnection.HTTP_MOVED_PERM; import static java.net.HttpURLConnection.HTTP_MOVED_TEMP; import static java.net.HttpURLConnection.HTTP_NOT_FOUND; import static java.net.HttpURLConnection.HTTP_OK; +import static java.net.HttpURLConnection.HTTP_PARTIAL; import static java.net.HttpURLConnection.HTTP_SEE_OTHER; /** @@ -61,6 +66,75 @@ final class ModelLoaderUtils { record VocabularyParts(List vocab, List merges, List scores) {} + // Range in bytes + record RequestRange(long rangeStart, long rangeEnd, int startPart, int numParts) { + public String bytesRange() { + return "bytes=" + rangeStart + "-" + rangeEnd; + } + } + + static class HttpStreamChunker { + + record BytesAndPartIndex(BytesArray bytes, int 
partIndex) {} + + private final InputStream inputStream; + private final int chunkSize; + private final AtomicLong totalBytesRead = new AtomicLong(); + private final AtomicInteger currentPart; + private final int lastPartNumber; + private final byte[] buf; + + HttpStreamChunker(URI uri, RequestRange range, int chunkSize) { + var inputStream = getHttpOrHttpsInputStream(uri, range); + this.inputStream = inputStream; + this.chunkSize = chunkSize; + this.lastPartNumber = range.startPart() + range.numParts(); + this.currentPart = new AtomicInteger(range.startPart()); + this.buf = new byte[chunkSize]; + } + + // This ctor exists for testing purposes only. + HttpStreamChunker(InputStream inputStream, RequestRange range, int chunkSize) { + this.inputStream = inputStream; + this.chunkSize = chunkSize; + this.lastPartNumber = range.startPart() + range.numParts(); + this.currentPart = new AtomicInteger(range.startPart()); + this.buf = new byte[chunkSize]; + } + + public boolean hasNext() { + return currentPart.get() < lastPartNumber; + } + + public BytesAndPartIndex next() throws IOException { + int bytesRead = 0; + + while (bytesRead < chunkSize) { + int read = inputStream.read(buf, bytesRead, chunkSize - bytesRead); + // EOF?? 
+ if (read == -1) { + break; + } + bytesRead += read; + } + + if (bytesRead > 0) { + totalBytesRead.addAndGet(bytesRead); + return new BytesAndPartIndex(new BytesArray(buf, 0, bytesRead), currentPart.getAndIncrement()); + } else { + return new BytesAndPartIndex(BytesArray.EMPTY, currentPart.get()); + } + } + + public long getTotalBytesRead() { + return totalBytesRead.get(); + } + + public int getCurrentPart() { + return currentPart.get(); + } + } + static class InputStreamChunker { private final InputStream inputStream; @@ -101,14 +175,14 @@ public int getTotalBytesRead() { } } - static InputStream getInputStreamFromModelRepository(URI uri) throws IOException { + static InputStream getInputStreamFromModelRepository(URI uri) { String scheme = uri.getScheme().toLowerCase(Locale.ROOT); // if you add a scheme here, also add it to the bootstrap check in {@link MachineLearningPackageLoader#validateModelRepository} switch (scheme) { case "http": case "https": - return getHttpOrHttpsInputStream(uri); + return getHttpOrHttpsInputStream(uri, null); case "file": return getFileInputStream(uri); default: @@ -116,6 +190,11 @@ static InputStream getInputStreamFromModelRepository(URI uri) throws IOException } } + static boolean uriIsFile(URI uri) { + String scheme = uri.getScheme().toLowerCase(Locale.ROOT); + return "file".equals(scheme); + } + static VocabularyParts loadVocabulary(URI uri) { if (uri.getPath().endsWith(".json")) { try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) { @@ -174,7 +253,7 @@ private ModelLoaderUtils() {} @SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ") @SuppressForbidden(reason = "we need socket connection to download") - private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException { + private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestRange range) { assert uri.getUserInfo() == null : "URI's with credentials are not supported"; @@ -186,18 
+265,30 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException PrivilegedAction privilegedHttpReader = () -> { try { HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection(); + if (range != null) { + conn.setRequestProperty("Range", range.bytesRange()); + } switch (conn.getResponseCode()) { case HTTP_OK: + case HTTP_PARTIAL: return conn.getInputStream(); + case HTTP_MOVED_PERM: case HTTP_MOVED_TEMP: case HTTP_SEE_OTHER: throw new IllegalStateException("redirects aren't supported yet"); case HTTP_NOT_FOUND: throw new ResourceNotFoundException("{} not found", uri); + case 416: // Range not satisfiable, for some reason not in the list of constants + throw new IllegalStateException("Invalid request range [" + range.bytesRange() + "]"); default: int responseCode = conn.getResponseCode(); - throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), uri); + throw new ElasticsearchStatusException( + "error during downloading {}. Got response code {}", + RestStatus.fromCode(responseCode), + uri, + responseCode + ); } } catch (IOException e) { throw new UncheckedIOException(e); @@ -209,7 +300,7 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException @SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ") @SuppressForbidden(reason = "we need load model data from a file") - private static InputStream getFileInputStream(URI uri) { + static InputStream getFileInputStream(URI uri) { SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -232,4 +323,53 @@ private static InputStream getFileInputStream(URI uri) { return AccessController.doPrivileged(privilegedFileReader); } + /** + * Split a stream of size {@code sizeInBytes} into {@code numberOfStreams} +1 + * ranges aligned on {@code chunkSizeBytes} boundaries. Each range contains a + * whole number of chunks. 
+ * The first {@code numberOfStreams} ranges will be split evenly (in terms of + * number of chunks not the byte size), the final range split + * is for the single final chunk and will be no more than {@code chunkSizeBytes} + * in size. The separate range for the final chunk is because when streaming and + * uploading a large model definition, writing the last part has to be handled + * as a special case. + * @param sizeInBytes The total size of the stream + * @param numberOfStreams Divide the bulk of the size into this many streams. + * @param chunkSizeBytes The size of each chunk + * @return List of {@code numberOfStreams} + 1 ranges. + */ + static List split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) { + int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes); + + var ranges = new ArrayList(); + + int baseChunksPerStream = numberOfChunks / numberOfStreams; + int remainder = numberOfChunks % numberOfStreams; + long startOffset = 0; + int startChunkIndex = 0; + + for (int i = 0; i < numberOfStreams - 1; i++) { + int numChunksInStream = (i < remainder) ? 
baseChunksPerStream + 1 : baseChunksPerStream; + long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes) - 1; // range index is 0 based + ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInStream)); + startOffset = rangeEnd + 1; // range is inclusive start and end + startChunkIndex += numChunksInStream; + } + + // Want the final range request to be a single chunk + if (baseChunksPerStream > 1) { + int numChunksExcludingFinal = baseChunksPerStream - 1; + long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes) - 1; + ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksExcludingFinal)); + + startOffset = rangeEnd + 1; + startChunkIndex += numChunksExcludingFinal; + } + + // The final range is a single chunk the end of which should not exceed sizeInBytes + long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerStream * chunkSizeBytes)) - 1; + ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, 1)); + + return ranges; + } } diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java index ba50f2f6a6b74..68f869742d9e5 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java @@ -77,7 +77,7 @@ protected void masterOperation(Task task, Request request, ClusterState state, A String packagedModelId = request.getPackagedModelId(); logger.debug(() -> format("Fetch package manifest for [%s] from [%s]", packagedModelId, repository)); - 
threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME).execute(() -> { + threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME).execute(() -> { try { URI uri = ModelLoaderUtils.resolvePackageLocation(repository, packagedModelId + ModelLoaderUtils.METADATA_FILE_EXTENSION); InputStream inputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri); diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java index 70dcee165d3f6..76b7781b1cffe 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java @@ -23,6 +23,7 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -37,14 +38,12 @@ import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction.Request; -import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; import java.io.IOException; import java.net.MalformedURLException; import java.net.URISyntaxException; import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import static 
org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; @@ -57,6 +56,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction< private static final Logger logger = LogManager.getLogger(TransportLoadTrainedModelPackage.class); private final Client client; + private final CircuitBreakerService circuitBreakerService; @Inject public TransportLoadTrainedModelPackage( @@ -65,7 +65,8 @@ public TransportLoadTrainedModelPackage( ThreadPool threadPool, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, - Client client + Client client, + CircuitBreakerService circuitBreakerService ) { super( LoadTrainedModelPackageAction.NAME, @@ -79,6 +80,7 @@ public TransportLoadTrainedModelPackage( EsExecutors.DIRECT_EXECUTOR_SERVICE ); this.client = new OriginSettingClient(client, ML_ORIGIN); + this.circuitBreakerService = circuitBreakerService; } @Override @@ -98,11 +100,14 @@ protected void masterOperation(Task task, Request request, ClusterState state, A parentTaskAssigningClient, request.getModelId(), request.getModelPackageConfig(), - downloadTask + downloadTask, + threadPool, + circuitBreakerService ); - threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME) - .execute(() -> importModel(client, taskManager, request, modelImporter, listener, downloadTask)); + var downloadCompleteListener = request.isWaitForCompletion() ? 
listener : ActionListener.noop(); + + importModel(client, taskManager, request, modelImporter, downloadCompleteListener, downloadTask); } catch (Exception e) { taskManager.unregister(downloadTask); listener.onFailure(e); @@ -136,16 +141,12 @@ static void importModel( ActionListener listener, Task task ) { - String modelId = request.getModelId(); - final AtomicReference exceptionRef = new AtomicReference<>(); - - try { - final long relativeStartNanos = System.nanoTime(); + final String modelId = request.getModelId(); + final long relativeStartNanos = System.nanoTime(); - logAndWriteNotificationAtLevel(auditClient, modelId, "starting model import", Level.INFO); - - modelImporter.doImport(); + logAndWriteNotificationAtLevel(auditClient, modelId, "starting model import", Level.INFO); + var finishListener = ActionListener.wrap(success -> { final long totalRuntimeNanos = System.nanoTime() - relativeStartNanos; logAndWriteNotificationAtLevel( auditClient, @@ -153,29 +154,25 @@ static void importModel( format("finished model import after [%d] seconds", TimeUnit.NANOSECONDS.toSeconds(totalRuntimeNanos)), Level.INFO ); - } catch (TaskCancelledException e) { - recordError(auditClient, modelId, exceptionRef, e, Level.WARNING); - } catch (ElasticsearchException e) { - recordError(auditClient, modelId, exceptionRef, e, Level.ERROR); - } catch (MalformedURLException e) { - recordError(auditClient, modelId, "an invalid URL", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); - } catch (URISyntaxException e) { - recordError(auditClient, modelId, "an invalid URL syntax", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); - } catch (IOException e) { - recordError(auditClient, modelId, "an IOException", exceptionRef, e, Level.ERROR, RestStatus.SERVICE_UNAVAILABLE); - } catch (Exception e) { - recordError(auditClient, modelId, "an Exception", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); - } finally { - taskManager.unregister(task); - - if 
(request.isWaitForCompletion()) { - if (exceptionRef.get() != null) { - listener.onFailure(exceptionRef.get()); - } else { - listener.onResponse(AcknowledgedResponse.TRUE); - } + listener.onResponse(AcknowledgedResponse.TRUE); + }, exception -> listener.onFailure(processException(auditClient, modelId, exception))); + + modelImporter.doImport(ActionListener.runAfter(finishListener, () -> taskManager.unregister(task))); + } - } + static Exception processException(Client auditClient, String modelId, Exception e) { + if (e instanceof TaskCancelledException te) { + return recordError(auditClient, modelId, te, Level.WARNING); + } else if (e instanceof ElasticsearchException es) { + return recordError(auditClient, modelId, es, Level.ERROR); + } else if (e instanceof MalformedURLException) { + return recordError(auditClient, modelId, "an invalid URL", e, Level.ERROR, RestStatus.BAD_REQUEST); + } else if (e instanceof URISyntaxException) { + return recordError(auditClient, modelId, "an invalid URL syntax", e, Level.ERROR, RestStatus.BAD_REQUEST); + } else if (e instanceof IOException) { + return recordError(auditClient, modelId, "an IOException", e, Level.ERROR, RestStatus.SERVICE_UNAVAILABLE); + } else { + return recordError(auditClient, modelId, "an Exception", e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); } } @@ -213,30 +210,16 @@ public ModelDownloadTask createTask(long id, String type, String action, TaskId } } - private static void recordError( - Client client, - String modelId, - AtomicReference exceptionRef, - ElasticsearchException e, - Level level - ) { + private static Exception recordError(Client client, String modelId, ElasticsearchException e, Level level) { String message = format("Model importing failed due to [%s]", e.getDetailedMessage()); logAndWriteNotificationAtLevel(client, modelId, message, level); - exceptionRef.set(e); + return e; } - private static void recordError( - Client client, - String modelId, - String failureType, - AtomicReference 
exceptionRef, - Exception e, - Level level, - RestStatus status - ) { + private static Exception recordError(Client client, String modelId, String failureType, Exception e, Level level, RestStatus status) { String message = format("Model importing failed due to %s [%s]", failureType, e); logAndWriteNotificationAtLevel(client, modelId, message, level); - exceptionRef.set(new ElasticsearchStatusException(message, status, e)); + return new ElasticsearchStatusException(message, status, e); } private static void logAndWriteNotificationAtLevel(Client client, String modelId, String message, Level level) { diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java index 967d1b4ba4b6a..2e487b6a9624c 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java @@ -7,9 +7,13 @@ package org.elasticsearch.xpack.ml.packageloader; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.PathUtils; import org.elasticsearch.test.ESTestCase; +import java.util.List; + import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -80,4 +84,12 @@ public void testValidateModelRepository() { assertEquals("xpack.ml.model_repository does not support authentication", e.getMessage()); } + + public void testThreadPoolHasSingleThread() { + var fixedThreadPool = MachineLearningPackageLoader.modelDownloadExecutor(Settings.EMPTY); + List> settings = fixedThreadPool.getRegisteredSettings(); + var sizeSettting = settings.stream().filter(s -> 
s.getKey().startsWith("xpack.ml.model_download_thread_pool")).findFirst(); + assertTrue(sizeSettting.isPresent()); + assertEquals(5, sizeSettting.get().get(Settings.EMPTY)); + } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java index 0afd08c70cf45..3a682fb6a5094 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java @@ -20,14 +20,7 @@ public class ModelDownloadTaskTests extends ESTestCase { public void testStatus() { - var task = new ModelDownloadTask( - 0L, - MODEL_IMPORT_TASK_TYPE, - MODEL_IMPORT_TASK_ACTION, - downloadModelTaskDescription("foo"), - TaskId.EMPTY_TASK_ID, - Map.of() - ); + var task = testTask(); task.setProgress(100, 0); var taskInfo = task.taskInfo("node", true); @@ -39,4 +32,15 @@ public void testStatus() { status = Strings.toString(taskInfo.status()); assertThat(status, containsString("{\"total_parts\":100,\"downloaded_parts\":1}")); } + + public static ModelDownloadTask testTask() { + return new ModelDownloadTask( + 0L, + MODEL_IMPORT_TASK_TYPE, + MODEL_IMPORT_TASK_ACTION, + downloadModelTaskDescription("foo"), + TaskId.EMPTY_TASK_ID, + Map.of() + ); + } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java new file mode 100644 index 0000000000000..cbcf74e69f588 --- /dev/null +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java @@ -0,0 +1,334 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.packageloader.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.LatchedActionListener; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.hash.MessageDigests; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; +import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; +import org.junit.After; +import org.junit.Before; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static 
org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ModelImporterTests extends ESTestCase { + + private TestThreadPool threadPool; + + @Before + public void createThreadPool() { + threadPool = createThreadPool(MachineLearningPackageLoader.modelDownloadExecutor(Settings.EMPTY)); + } + + @After + public void closeThreadPool() { + threadPool.close(); + } + + public void testDownloadModelDefinition() throws InterruptedException, URISyntaxException { + var client = mockClient(false); + var task = ModelDownloadTaskTests.testTask(); + var config = mockConfigWithRepoLinks(); + var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); + var cbs = mock(CircuitBreakerService.class); + when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class)); + + int totalParts = 5; + int chunkSize = 10; + long size = totalParts * chunkSize; + var modelDef = modelDefinition(totalParts, chunkSize); + var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); + + var digest = computeDigest(modelDef); + when(config.getSha256()).thenReturn(digest); + when(config.getSize()).thenReturn(size); + + var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener(ActionTestUtils.assertNoFailureListener(ignore -> {}), latch); + importer.downloadModelDefinition(size, totalParts, vocab, streamers, latchedListener); + + latch.await(); + verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any()); + assertEquals(totalParts - 1, task.getStatus().downloadProgress().downloadedParts()); + assertEquals(totalParts, task.getStatus().downloadProgress().totalParts()); + } + + public void testReadModelDefinitionFromFile() throws InterruptedException, URISyntaxException { + var client = mockClient(false); + var task = 
ModelDownloadTaskTests.testTask(); + var config = mockConfigWithRepoLinks(); + var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); + var cbs = mock(CircuitBreakerService.class); + when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class)); + + int totalParts = 3; + int chunkSize = 10; + long size = totalParts * chunkSize; + var modelDef = modelDefinition(totalParts, chunkSize); + + var digest = computeDigest(modelDef); + when(config.getSha256()).thenReturn(digest); + when(config.getSize()).thenReturn(size); + + var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs); + var streamChunker = new ModelLoaderUtils.InputStreamChunker(new ByteArrayInputStream(modelDef), chunkSize); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener(ActionTestUtils.assertNoFailureListener(ignore -> {}), latch); + importer.readModelDefinitionFromFile(size, totalParts, streamChunker, vocab, latchedListener); + + latch.await(); + verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any()); + assertEquals(totalParts, task.getStatus().downloadProgress().downloadedParts()); + assertEquals(totalParts, task.getStatus().downloadProgress().totalParts()); + } + + public void testSizeMismatch() throws InterruptedException, URISyntaxException { + var client = mockClient(false); + var task = mock(ModelDownloadTask.class); + var config = mockConfigWithRepoLinks(); + var cbs = mock(CircuitBreakerService.class); + when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class)); + + int totalParts = 5; + int chunkSize = 10; + long size = totalParts * chunkSize; + var modelDef = modelDefinition(totalParts, chunkSize); + var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); + + var digest = computeDigest(modelDef); + when(config.getSha256()).thenReturn(digest); + when(config.getSize()).thenReturn(size - 1); // expected 
size and read size are different + + var exceptionHolder = new AtomicReference(); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener( + ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), + latch + ); + + var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs); + importer.downloadModelDefinition(size, totalParts, null, streamers, latchedListener); + + latch.await(); + assertThat(exceptionHolder.get().getMessage(), containsString("Model size does not match")); + verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any()); + } + + public void testDigestMismatch() throws InterruptedException, URISyntaxException { + var client = mockClient(false); + var task = mock(ModelDownloadTask.class); + var config = mockConfigWithRepoLinks(); + var cbs = mock(CircuitBreakerService.class); + when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class)); + + int totalParts = 5; + int chunkSize = 10; + long size = totalParts * chunkSize; + var modelDef = modelDefinition(totalParts, chunkSize); + var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); + + when(config.getSha256()).thenReturn("0x"); // digest is different + when(config.getSize()).thenReturn(size); + + var exceptionHolder = new AtomicReference(); + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener( + ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), + latch + ); + + var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs); + // Message digest can only be calculated for the file reader + var streamChunker = new ModelLoaderUtils.InputStreamChunker(new ByteArrayInputStream(modelDef), chunkSize); + importer.readModelDefinitionFromFile(size, totalParts, streamChunker, null, latchedListener); + + latch.await(); + assertThat(exceptionHolder.get().getMessage(), containsString("Model sha256 checksums do not match")); 
+ verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any()); + } + + public void testPutFailure() throws InterruptedException, URISyntaxException { + var client = mockClient(true); // client will fail put + var task = mock(ModelDownloadTask.class); + var config = mockConfigWithRepoLinks(); + var cbs = mock(CircuitBreakerService.class); + when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class)); + + int totalParts = 4; + int chunkSize = 10; + long size = totalParts * chunkSize; + var modelDef = modelDefinition(totalParts, chunkSize); + var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 1); + + var exceptionHolder = new AtomicReference(); + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener( + ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), + latch + ); + + var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs); + importer.downloadModelDefinition(size, totalParts, null, streamers, latchedListener); + + latch.await(); + assertThat(exceptionHolder.get().getMessage(), containsString("put model part failed")); + verify(client, times(1)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any()); + } + + public void testReadFailure() throws IOException, InterruptedException, URISyntaxException { + var client = mockClient(true); + var task = mock(ModelDownloadTask.class); + var config = mockConfigWithRepoLinks(); + var cbs = mock(CircuitBreakerService.class); + when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class)); + + int totalParts = 4; + int chunkSize = 10; + long size = totalParts * chunkSize; + + var streamer = mock(ModelLoaderUtils.HttpStreamChunker.class); + when(streamer.hasNext()).thenReturn(true); + when(streamer.next()).thenThrow(new IOException("stream failed")); // fail the read + + var exceptionHolder = new AtomicReference(); + var latch = new CountDownLatch(1); + var 
latchedListener = new LatchedActionListener( + ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), + latch + ); + + var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs); + importer.downloadModelDefinition(size, totalParts, null, List.of(streamer), latchedListener); + + latch.await(); + assertThat(exceptionHolder.get().getMessage(), containsString("stream failed")); + } + + @SuppressWarnings("unchecked") + public void testUploadVocabFailure() throws InterruptedException, URISyntaxException { + var client = mock(Client.class); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new ElasticsearchStatusException("put vocab failed", RestStatus.BAD_REQUEST)); + return null; + }).when(client).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); + var cbs = mock(CircuitBreakerService.class); + when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class)); + + var task = mock(ModelDownloadTask.class); + var config = mockConfigWithRepoLinks(); + + var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); + + var exceptionHolder = new AtomicReference(); + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener( + ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), + latch + ); + + var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs); + importer.downloadModelDefinition(100, 5, vocab, List.of(), latchedListener); + + latch.await(); + assertThat(exceptionHolder.get().getMessage(), containsString("put vocab failed")); + verify(client, times(1)).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); + verify(client, never()).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any()); + } + + private List mockHttpStreamChunkers(byte[] modelDef, int chunkSize, int numStreams) { + var ranges = 
ModelLoaderUtils.split(modelDef.length, numStreams, chunkSize); + + var result = new ArrayList(ranges.size()); + for (var range : ranges) { + int len = range.numParts() * chunkSize; + var modelDefStream = new ByteArrayInputStream(modelDef, (int) range.rangeStart(), len); + result.add(new ModelLoaderUtils.HttpStreamChunker(modelDefStream, range, chunkSize)); + } + + return result; + } + + private byte[] modelDefinition(int totalParts, int chunkSize) { + var bytes = new byte[totalParts * chunkSize]; + for (int i = 0; i < totalParts; i++) { + System.arraycopy(randomByteArrayOfLength(chunkSize), 0, bytes, i * chunkSize, chunkSize); + } + return bytes; + } + + private String computeDigest(byte[] modelDef) { + var digest = MessageDigests.sha256(); + digest.update(modelDef); + return MessageDigests.toHexString(digest.digest()); + } + + @SuppressWarnings("unchecked") + private Client mockClient(boolean failPutPart) { + var client = mock(Client.class); + + if (failPutPart) { + when(client.execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any())).thenThrow( + new IllegalStateException("put model part failed") + ); + } else { + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(AcknowledgedResponse.TRUE); + when(client.execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any())).thenReturn(future); + } + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(AcknowledgedResponse.TRUE); + return null; + }).when(client).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); + + return client; + } + + private ModelPackageConfig mockConfigWithRepoLinks() { + var config = mock(ModelPackageConfig.class); + when(config.getModelRepository()).thenReturn("https://models.models"); + when(config.getPackagedModelId()).thenReturn("my-model"); + return config; + } +} diff --git 
a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java index 661cd12f99957..f421a7b44e7f1 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java @@ -17,6 +17,7 @@ import java.nio.charset.StandardCharsets; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.core.Is.is; public class ModelLoaderUtilsTests extends ESTestCase { @@ -80,14 +81,13 @@ public void testSha256AndSize() throws IOException { assertEquals(64, expectedDigest.length()); int chunkSize = randomIntBetween(100, 10_000); + int totalParts = (bytes.length + chunkSize - 1) / chunkSize; ModelLoaderUtils.InputStreamChunker inputStreamChunker = new ModelLoaderUtils.InputStreamChunker( new ByteArrayInputStream(bytes), chunkSize ); - int totalParts = (bytes.length + chunkSize - 1) / chunkSize; - for (int part = 0; part < totalParts - 1; ++part) { assertEquals(chunkSize, inputStreamChunker.next().length()); } @@ -112,4 +112,40 @@ public void testParseVocabulary() throws IOException { assertThat(parsedVocab.merges(), contains("mergefoo", "mergebar", "mergebaz")); assertThat(parsedVocab.scores(), contains(1.0, 2.0, 3.0)); } + + public void testSplitIntoRanges() { + long totalSize = randomLongBetween(10_000, 50_000_000); + int numStreams = randomIntBetween(1, 10); + int chunkSize = 1024; + var ranges = ModelLoaderUtils.split(totalSize, numStreams, chunkSize); + assertThat(ranges, hasSize(numStreams + 1)); + + int expectedNumChunks = (int) ((totalSize + chunkSize - 1) / chunkSize); + 
assertThat(ranges.stream().mapToInt(ModelLoaderUtils.RequestRange::numParts).sum(), is(expectedNumChunks)); + + long startBytes = 0; + int startPartIndex = 0; + for (int i = 0; i < ranges.size() - 1; i++) { + assertThat(ranges.get(i).rangeStart(), is(startBytes)); + long end = startBytes + ((long) ranges.get(i).numParts() * chunkSize) - 1; + assertThat(ranges.get(i).rangeEnd(), is(end)); + long expectedNumBytesInRange = (long) chunkSize * ranges.get(i).numParts() - 1; + assertThat(ranges.get(i).rangeEnd() - ranges.get(i).rangeStart(), is(expectedNumBytesInRange)); + assertThat(ranges.get(i).startPart(), is(startPartIndex)); + + startBytes = end + 1; + startPartIndex += ranges.get(i).numParts(); + } + + var finalRange = ranges.get(ranges.size() - 1); + assertThat(finalRange.rangeStart(), is(startBytes)); + assertThat(finalRange.rangeEnd(), is(totalSize - 1)); + assertThat(finalRange.numParts(), is(1)); + } + + public void testRangeRequestBytesRange() { + long start = randomLongBetween(0, 2 << 10); + long end = randomLongBetween(start + 1, 2 << 11); + assertEquals("bytes=" + start + "-" + end, new ModelLoaderUtils.RequestRange(start, end, 0, 1).bytesRange()); + } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java index a3f59e13f2f5b..cbcfd5b760779 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java @@ -33,7 +33,7 @@ import static org.hamcrest.core.Is.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doThrow; 
+import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -42,7 +42,7 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase { private static final String MODEL_IMPORT_FAILURE_MSG_FORMAT = "Model importing failed due to %s [%s]"; public void testSendsFinishedUploadNotification() { - var uploader = mock(ModelImporter.class); + var uploader = createUploader(null); var taskManager = mock(TaskManager.class); var task = mock(Task.class); var client = mock(Client.class); @@ -63,49 +63,49 @@ public void testSendsFinishedUploadNotification() { assertThat(notificationArg.getValue().getMessage(), CoreMatchers.containsString("finished model import after")); } - public void testSendsErrorNotificationForInternalError() throws URISyntaxException, IOException { + public void testSendsErrorNotificationForInternalError() throws Exception { ElasticsearchStatusException exception = new ElasticsearchStatusException("exception", RestStatus.INTERNAL_SERVER_ERROR); String message = format("Model importing failed due to [%s]", exception.toString()); assertUploadCallsOnFailure(exception, message, Level.ERROR); } - public void testSendsErrorNotificationForMalformedURL() throws URISyntaxException, IOException { + public void testSendsErrorNotificationForMalformedURL() throws Exception { MalformedURLException exception = new MalformedURLException("exception"); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL", exception.toString()); - assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR); + assertUploadCallsOnFailure(exception, message, RestStatus.BAD_REQUEST, Level.ERROR); } - public void testSendsErrorNotificationForURISyntax() throws URISyntaxException, IOException { + public void testSendsErrorNotificationForURISyntax() throws Exception { URISyntaxException exception = 
mock(URISyntaxException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL syntax", exception.toString()); - assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR); + assertUploadCallsOnFailure(exception, message, RestStatus.BAD_REQUEST, Level.ERROR); } - public void testSendsErrorNotificationForIOException() throws URISyntaxException, IOException { + public void testSendsErrorNotificationForIOException() throws Exception { IOException exception = mock(IOException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an IOException", exception.toString()); assertUploadCallsOnFailure(exception, message, RestStatus.SERVICE_UNAVAILABLE, Level.ERROR); } - public void testSendsErrorNotificationForException() throws URISyntaxException, IOException { + public void testSendsErrorNotificationForException() throws Exception { RuntimeException exception = mock(RuntimeException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an Exception", exception.toString()); assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR); } - public void testSendsWarningNotificationForTaskCancelledException() throws URISyntaxException, IOException { + public void testSendsWarningNotificationForTaskCancelledException() throws Exception { TaskCancelledException exception = new TaskCancelledException("cancelled"); String message = format("Model importing failed due to [%s]", exception.toString()); assertUploadCallsOnFailure(exception, message, Level.WARNING); } - public void testCallsOnResponseWithAcknowledgedResponse() throws URISyntaxException, IOException { + public void testCallsOnResponseWithAcknowledgedResponse() throws Exception { var client = mock(Client.class); var taskManager = mock(TaskManager.class); var task = mock(Task.class); @@ -134,15 +134,13 @@ public void testDoesNotCallListenerWhenNotWaitingForCompletion() { ); } - private void 
assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws URISyntaxException, - IOException { + private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws Exception { var esStatusException = new ElasticsearchStatusException(message, status, exception); assertNotificationAndOnFailure(exception, esStatusException, message, level); } - private void assertUploadCallsOnFailure(ElasticsearchException exception, String message, Level level) throws URISyntaxException, - IOException { + private void assertUploadCallsOnFailure(ElasticsearchException exception, String message, Level level) throws Exception { assertNotificationAndOnFailure(exception, exception, message, level); } @@ -151,7 +149,7 @@ private void assertNotificationAndOnFailure( ElasticsearchException onFailureException, String message, Level level - ) throws URISyntaxException, IOException { + ) throws Exception { var client = mock(Client.class); var taskManager = mock(TaskManager.class); var task = mock(Task.class); @@ -179,11 +177,18 @@ private void assertNotificationAndOnFailure( verify(taskManager).unregister(task); } - private ModelImporter createUploader(Exception exception) throws URISyntaxException, IOException { + @SuppressWarnings("unchecked") + private ModelImporter createUploader(Exception exception) { ModelImporter uploader = mock(ModelImporter.class); - if (exception != null) { - doThrow(exception).when(uploader).doImport(); - } + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[0]; + if (exception != null) { + listener.onFailure(exception); + } else { + listener.onResponse(AcknowledgedResponse.TRUE); + } + return null; + }).when(uploader).doImport(any(ActionListener.class)); return uploader; }