Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/111684.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111684
summary: Write downloaded model parts async
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction;
import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask;
import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter;
import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage;

Expand All @@ -44,16 +49,15 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin
Setting.Property.Dynamic
);

// re-using thread pool setup by the ml plugin
public static final String UTILITY_THREAD_POOL_NAME = "ml_utility";

// This link will be invalid for serverless, but serverless will never be
// air-gapped, so this message should never be needed.
private static final String MODEL_REPOSITORY_DOCUMENTATION_LINK = format(
"https://www.elastic.co/guide/en/machine-learning/%s/ml-nlp-elser.html#air-gapped-install",
Build.current().version().replaceFirst("^(\\d+\\.\\d+).*", "$1")
);

public static final String MODEL_DOWNLOAD_THREADPOOL_NAME = "model_download";

// No-args constructor: all plugin wiring happens via the Plugin/ActionPlugin overrides below.
public MachineLearningPackageLoader() {}

@Override
Expand Down Expand Up @@ -81,6 +85,24 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);
}

/**
 * Registers the executors this plugin contributes to the node's thread pool.
 * Currently only the fixed-size model download pool built by {@link #modelDownloadExecutor(Settings)}.
 *
 * @param settings the node settings passed to the executor builder
 * @return a single-element list containing the model download executor builder
 */
@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return List.of(modelDownloadExecutor(settings));
}

/**
 * Builds the fixed thread pool used to download model definition files.
 * The pool has {@code ModelImporter.NUMBER_OF_STREAMS} threads — one per
 * concurrent download stream — and an unbounded queue (size -1).
 *
 * @param settings node settings forwarded to the builder
 * @return the builder for the {@code model_download} thread pool
 */
public static FixedExecutorBuilder modelDownloadExecutor(Settings settings) {
// Threadpool with a fixed number of threads for
// downloading the model definition files
return new FixedExecutorBuilder(
settings,
MODEL_DOWNLOAD_THREADPOOL_NAME,
ModelImporter.NUMBER_OF_STREAMS,
-1, // unbounded queue size
"xpack.ml.model_download_thread_pool",
EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
);
}

@Override
public List<BootstrapCheck> getBootstrapChecks() {
return List.of(new BootstrapCheck() {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentParser;
Expand All @@ -34,16 +35,20 @@
import java.security.AccessController;
import java.security.MessageDigest;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;

import static java.net.HttpURLConnection.HTTP_MOVED_PERM;
import static java.net.HttpURLConnection.HTTP_MOVED_TEMP;
import static java.net.HttpURLConnection.HTTP_NOT_FOUND;
import static java.net.HttpURLConnection.HTTP_OK;
import static java.net.HttpURLConnection.HTTP_PARTIAL;
import static java.net.HttpURLConnection.HTTP_SEE_OTHER;

/**
Expand All @@ -61,6 +66,73 @@ final class ModelLoaderUtils {

record VocabularyParts(List<String> vocab, List<String> merges, List<Double> scores) {}

/**
 * An inclusive byte range [rangeStart, rangeEnd] of the model definition,
 * together with the index of the first model part it covers and the number
 * of parts it spans.
 */
record RequestRange(long rangeStart, long rangeEnd, int startPart, int numParts) {

    /** Renders this range as an HTTP {@code Range} header value, e.g. {@code bytes=0-1023}. */
    public String bytesRange() {
        return String.format(Locale.ROOT, "bytes=%d-%d", rangeStart, rangeEnd);
    }
}

static class HttpStreamChunker {

record BytesAndPartIndex(BytesArray bytes, int partIndex) {}

private final InputStream inputStream;
private final int chunkSize;
private final AtomicLong totalBytesRead = new AtomicLong();
private final AtomicInteger currentPart;
private final int lastPartNumber;

HttpStreamChunker(URI uri, RequestRange range, int chunkSize) {
var inputStream = getHttpOrHttpsInputStream(uri, range);
this.inputStream = inputStream;
this.chunkSize = chunkSize;
this.lastPartNumber = range.startPart() + range.numParts();
this.currentPart = new AtomicInteger(range.startPart());
}

// This ctor exists for testing purposes only.
HttpStreamChunker(InputStream inputStream, RequestRange range, int chunkSize) {
this.inputStream = inputStream;
this.chunkSize = chunkSize;
this.lastPartNumber = range.startPart() + range.numParts();
this.currentPart = new AtomicInteger(range.startPart());
}

public boolean hasNext() {
return currentPart.get() < lastPartNumber;
}

public BytesAndPartIndex next() throws IOException {
int bytesRead = 0;
byte[] buf = new byte[chunkSize];

while (bytesRead < chunkSize) {
int read = inputStream.read(buf, bytesRead, chunkSize - bytesRead);
// EOF??
if (read == -1) {
break;
}
bytesRead += read;
}

if (bytesRead > 0) {
totalBytesRead.addAndGet(bytesRead);
return new BytesAndPartIndex(new BytesArray(buf, 0, bytesRead), currentPart.getAndIncrement());
} else {
return new BytesAndPartIndex(BytesArray.EMPTY, currentPart.get());
}
}

public long getTotalBytesRead() {
return totalBytesRead.get();
}

public int getCurrentPart() {
return currentPart.get();
}
}

static class InputStreamChunker {

private final InputStream inputStream;
Expand Down Expand Up @@ -101,21 +173,26 @@ public int getTotalBytesRead() {
}
}

static InputStream getInputStreamFromModelRepository(URI uri) throws IOException {
static InputStream getInputStreamFromModelRepository(URI uri) {
String scheme = uri.getScheme().toLowerCase(Locale.ROOT);

// if you add a scheme here, also add it to the bootstrap check in {@link MachineLearningPackageLoader#validateModelRepository}
switch (scheme) {
case "http":
case "https":
return getHttpOrHttpsInputStream(uri);
return getHttpOrHttpsInputStream(uri, null);
case "file":
return getFileInputStream(uri);
default:
throw new IllegalArgumentException("unsupported scheme");
}
}

/** Returns true when {@code uri} uses the {@code file} scheme (compared case-insensitively). */
static boolean uriIsFile(URI uri) {
    String scheme = uri.getScheme();
    return scheme.toLowerCase(Locale.ROOT).equals("file");
}

static VocabularyParts loadVocabulary(URI uri) {
if (uri.getPath().endsWith(".json")) {
try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) {
Expand Down Expand Up @@ -174,7 +251,7 @@ private ModelLoaderUtils() {}

@SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ")
@SuppressForbidden(reason = "we need socket connection to download")
private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException {
private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestRange range) {

assert uri.getUserInfo() == null : "URI's with credentials are not supported";

Expand All @@ -186,18 +263,30 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException
PrivilegedAction<InputStream> privilegedHttpReader = () -> {
try {
HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection();
if (range != null) {
conn.setRequestProperty("Range", range.bytesRange());
}
switch (conn.getResponseCode()) {
case HTTP_OK:
case HTTP_PARTIAL:
return conn.getInputStream();

case HTTP_MOVED_PERM:
case HTTP_MOVED_TEMP:
case HTTP_SEE_OTHER:
throw new IllegalStateException("redirects aren't supported yet");
case HTTP_NOT_FOUND:
throw new ResourceNotFoundException("{} not found", uri);
case 416: // Range not satisfiable, for some reason not in the list of constants
throw new IllegalStateException("Invalid request range [" + range.bytesRange() + "]");
default:
int responseCode = conn.getResponseCode();
throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), uri);
throw new ElasticsearchStatusException(
"error during downloading {}. Got response code {}",
RestStatus.fromCode(responseCode),
uri,
responseCode
);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
Expand All @@ -209,7 +298,7 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException

@SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ")
@SuppressForbidden(reason = "we need load model data from a file")
private static InputStream getFileInputStream(URI uri) {
static InputStream getFileInputStream(URI uri) {

SecurityManager sm = System.getSecurityManager();
if (sm != null) {
Expand All @@ -232,4 +321,53 @@ private static InputStream getFileInputStream(URI uri) {
return AccessController.doPrivileged(privilegedFileReader);
}

/**
 * Split a stream of size {@code sizeInBytes} into {@code numberOfStreams} +1
 * ranges aligned on {@code chunkSizeBytes} boundaries. Each range contains a
 * whole number of chunks.
 * The first {@code numberOfStreams} ranges will be split evenly (in terms of
 * number of chunks not the byte size), the final range split
 * is for the single final chunk and will be no more than {@code chunkSizeBytes}
 * in size. The separate range for the final chunk is because when streaming and
 * uploading a large model definition, writing the last part has to be handled
 * as a special case.
 * @param sizeInBytes The total size of the stream
 * @param numberOfStreams Divide the bulk of the size into this many streams.
 * @param chunkSizeBytes The size of each chunk
 * @return List of {@code numberOfStreams} + 1 ranges.
 */
static List<RequestRange> split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) {
    // Ceiling division: number of chunkSizeBytes-sized chunks needed to cover the stream.
    // NOTE(review): assumes numberOfStreams >= 1 and numberOfChunks >= numberOfStreams;
    // smaller inputs produce degenerate zero-length ranges — confirm callers guarantee this.
    int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes);

    var ranges = new ArrayList<RequestRange>();

    int baseChunksPerStream = numberOfChunks / numberOfStreams;
    int remainder = numberOfChunks % numberOfStreams;
    long startOffset = 0;
    int startChunkIndex = 0;

    // The first numberOfStreams - 1 ranges: distribute chunks as evenly as possible,
    // giving the first `remainder` streams one extra chunk each.
    for (int i = 0; i < numberOfStreams - 1; i++) {
        int numChunksInStream = (i < remainder) ? baseChunksPerStream + 1 : baseChunksPerStream;
        long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes) - 1; // range index is 0 based
        ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInStream));
        startOffset = rangeEnd + 1; // range is inclusive start and end
        startChunkIndex += numChunksInStream;
    }

    // Want the final range request to be a single chunk
    // Emit the last stream's chunks minus its final chunk, which is reserved for the range below.
    if (baseChunksPerStream > 1) {
        int numChunksExcludingFinal = baseChunksPerStream - 1;
        long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes) - 1;
        ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksExcludingFinal));

        startOffset = rangeEnd + 1;
        startChunkIndex += numChunksExcludingFinal;
    }

    // The final range is a single chunk the end of which should not exceed sizeInBytes
    // (min() caps the end when sizeInBytes is not an exact multiple of chunkSizeBytes).
    long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerStream * chunkSizeBytes)) - 1;
    ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, 1));

    return ranges;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ protected void masterOperation(Task task, Request request, ClusterState state, A
String packagedModelId = request.getPackagedModelId();
logger.debug(() -> format("Fetch package manifest for [%s] from [%s]", packagedModelId, repository));

threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME).execute(() -> {
threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME).execute(() -> {
try {
URI uri = ModelLoaderUtils.resolvePackageLocation(repository, packagedModelId + ModelLoaderUtils.METADATA_FILE_EXTENSION);
InputStream inputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri);
Expand Down
Loading