Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
2e854e4
Write model parts async
davidkyle Aug 7, 2024
125c822
Update docs/changelog/111684.yaml
davidkyle Aug 7, 2024
be39082
Pass a listener to import
davidkyle Aug 9, 2024
5310a08
Ref counting WIP
davidkyle Aug 9, 2024
3a57e14
use ref counting listener
davidkyle Aug 14, 2024
56f5e1c
tidying
davidkyle Aug 14, 2024
9f98dbe
add tests
davidkyle Aug 14, 2024
e434499
Merge branch 'main' into background-download-write
elasticmachine Aug 14, 2024
9d9a0e6
Add download threadpool
davidkyle Aug 19, 2024
669909b
less blocking
davidkyle Aug 19, 2024
fcc66b4
tidy up
davidkyle Aug 19, 2024
1921812
more tests
davidkyle Aug 21, 2024
63bdff6
Merge branch 'main' into background-download-write
elasticmachine Aug 21, 2024
84aec83
remove unused
davidkyle Aug 22, 2024
c4ea146
fix the tests
davidkyle Aug 22, 2024
413a3ab
5 in flight requests
davidkyle Aug 22, 2024
70ffe07
use another threadpool for writes
davidkyle Aug 27, 2024
573d5af
Revert "use another threadpool for writes"
davidkyle Aug 29, 2024
658b903
use range request
davidkyle Sep 3, 2024
98d8020
Merge branch 'main' into background-download-write
davidkyle Sep 9, 2024
1d9f5b5
Use multiple connections
davidkyle Sep 9, 2024
d718a92
Merge branch 'main' into background-download-write
elasticmachine Sep 10, 2024
0395c76
Tidy comments
davidkyle Sep 10, 2024
d37cd3a
short threads
davidkyle Sep 10, 2024
6e88b09
Merge branch 'main' into background-download-write
elasticmachine Sep 11, 2024
73f0f6b
Enable markSupported in SlicedInputStream (#112563)
kingherc Sep 11, 2024
830c26f
Make two SubReaderWrapper implementations singletons (#112596)
original-brownbear Sep 11, 2024
02a9750
Two speedups to IndexNameExpressionResolver (#112486)
original-brownbear Sep 11, 2024
7f66489
Fix failing LangMustacheClientYamlTestSuiteIT yamlRestTestV7CompatTes…
cbuescher Sep 11, 2024
ae94a40
Fix trappy timeouts in downsample action (#112734)
DaveCTurner Sep 11, 2024
06a361a
Mute org.elasticsearch.smoketest.DocsClientYamlTestSuiteIT test {yaml…
elasticsearchmachine Sep 11, 2024
802ee00
Introduce repository integrity verification API (#112348)
DaveCTurner Sep 11, 2024
75c0fb9
JSON parse failures should be 4xx codes (#112703)
benwtrent Sep 11, 2024
07329d7
Handle null exception message in `TestCluster#wipe` (#112741)
DaveCTurner Sep 11, 2024
eabea6f
Add TaskManager to pluginServices (#112687)
parkertimmins Sep 11, 2024
051f504
Introduce data stream options and failure store configuration classes…
gmarouli Sep 11, 2024
a148619
Support widening of numeric types in union-types (#112610)
craigtaverner Sep 11, 2024
b9b62db
add CDR related data streams to kibana_system priviliges (#112655)
maxcold Sep 11, 2024
1a05488
Bump Elasticsearch version to 9.0.0 (#112570)
mark-vieira Sep 11, 2024
ca2b144
ESQL: Compute support for filtering ungrouped aggs (#112717)
nik9000 Sep 11, 2024
bceeced
Bump Elasticsearch to a minimum of JDK 21 (#112252)
ChrisHegarty Sep 11, 2024
a211d80
Mute org.elasticsearch.repositories.blobstore.testkit.integrity.Repos…
elasticsearchmachine Sep 11, 2024
cef6f0b
[DOCS] Augment installation warnings (#112756)
lcawl Sep 11, 2024
196728e
(Doc+) CAT Nodes default columns (#112715)
stefnestor Sep 11, 2024
c4932f2
(Doc+) Terminating Exit Codes (#112530)
stefnestor Sep 11, 2024
2b426f7
Fix verifyVersions task (#112765)
mark-vieira Sep 11, 2024
7b17077
(Doc+) Inference Pipeline ignores Mapping Analyzers (#112522)
stefnestor Sep 11, 2024
18a48c7
Estimate segment field usages (#112760)
dnhatn Sep 11, 2024
e0044a5
Use a dedicated test executor in MockTransportService (#112748)
ywangd Sep 12, 2024
2046b29
Mute org.elasticsearch.repositories.blobstore.testkit.integrity.Repos…
elasticsearchmachine Sep 12, 2024
9f5b528
Mute org.elasticsearch.script.StatsSummaryTests testEqualsAndHashCode…
elasticsearchmachine Sep 12, 2024
6fdb78c
Do not throw in task enqueued by CancellableRunner (#112780)
ywangd Sep 12, 2024
ca30b69
[Test] Account for auto-repairing for shard gen file (#112778)
ywangd Sep 12, 2024
1bbb739
Introduce test utils for ingest pipelines (#112733)
DaveCTurner Sep 12, 2024
6ef94ac
Deduplicate BucketOrder when deserializing (#112707)
iverase Sep 12, 2024
3fadf53
Block use of current version feature in yaml tests (#112737)
thecoop Sep 12, 2024
b801949
ci(bump automation): bump ubi9 for ironbank (#112298)
v1v Sep 12, 2024
c3c4aa5
Two empty mappings now are created equally (#107936)
piergm Sep 12, 2024
af1ba75
Cleanup shutdown module bwc in v9 (#112793)
DaveCTurner Sep 12, 2024
a14f529
Update last few references in yaml tests from ROOT locale to ENGLISH …
thecoop Sep 12, 2024
7aa98ef
Remove adaptive allocations feature flag (#112798)
jan-elastic Sep 12, 2024
352dd89
address comments
davidkyle Sep 12, 2024
a2c3c5e
Merge branch 'main' into background-download-write
elasticmachine Sep 12, 2024
d0fbe17
fix recovery from failure
davidkyle Sep 12, 2024
81ce3f1
Merge branch 'main' into background-download-write
elasticmachine Sep 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/111684.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111684
summary: Write downloaded model parts async
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction;
import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask;
import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter;
import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage;

Expand All @@ -44,16 +49,15 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin
Setting.Property.Dynamic
);

// re-using thread pool setup by the ml plugin
public static final String UTILITY_THREAD_POOL_NAME = "ml_utility";

// This link will be invalid for serverless, but serverless will never be
// air-gapped, so this message should never be needed.
private static final String MODEL_REPOSITORY_DOCUMENTATION_LINK = format(
"https://www.elastic.co/guide/en/machine-learning/%s/ml-nlp-elser.html#air-gapped-install",
Build.current().version().replaceFirst("^(\\d+\\.\\d+).*", "$1")
);

public static final String MODEL_DOWNLOAD_THREADPOOL_NAME = "model_download";

public MachineLearningPackageLoader() {}

@Override
Expand Down Expand Up @@ -81,6 +85,24 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);
}

/**
 * Registers this plugin's executors with the node's thread pools.
 * Contributes the single fixed-size model-download pool built by
 * {@link #modelDownloadExecutor(Settings)}.
 */
@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
    List<ExecutorBuilder<?>> downloadPoolBuilders = List.of(modelDownloadExecutor(settings));
    return downloadPoolBuilders;
}

/**
 * Builds the fixed-size thread pool used for downloading model definition
 * files. The thread count matches the number of concurrent download streams
 * and the work queue is unbounded.
 *
 * @param settings node settings passed through to the executor builder
 * @return a builder for the {@code model_download} thread pool
 */
public static FixedExecutorBuilder modelDownloadExecutor(Settings settings) {
    final int downloadThreads = ModelImporter.NUMBER_OF_STREAMS; // one thread per download stream
    final int unboundedQueueSize = -1; // a negative size means the queue is unbounded
    return new FixedExecutorBuilder(
        settings,
        MODEL_DOWNLOAD_THREADPOOL_NAME,
        downloadThreads,
        unboundedQueueSize,
        "xpack.ml.model_download_thread_pool",
        EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
    );
}

@Override
public List<BootstrapCheck> getBootstrapChecks() {
return List.of(new BootstrapCheck() {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentParser;
Expand All @@ -34,16 +35,20 @@
import java.security.AccessController;
import java.security.MessageDigest;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;

import static java.net.HttpURLConnection.HTTP_MOVED_PERM;
import static java.net.HttpURLConnection.HTTP_MOVED_TEMP;
import static java.net.HttpURLConnection.HTTP_NOT_FOUND;
import static java.net.HttpURLConnection.HTTP_OK;
import static java.net.HttpURLConnection.HTTP_PARTIAL;
import static java.net.HttpURLConnection.HTTP_SEE_OTHER;

/**
Expand All @@ -61,6 +66,73 @@ final class ModelLoaderUtils {

record VocabularyParts(List<String> vocab, List<String> merges, List<Double> scores) {}

// An inclusive byte range [rangeStart, rangeEnd] covering download parts
// startPart .. startPart + numParts - 1.
record RequestRange(long rangeStart, long rangeEnd, int startPart, int numParts) {

    /**
     * Formats this range as an HTTP {@code Range} header value,
     * e.g. {@code bytes=0-1023}.
     */
    public String bytesRange() {
        return String.format(Locale.ROOT, "bytes=%d-%d", rangeStart, rangeEnd);
    }
}

static class HttpStreamChunker {

record BytesAndPartIndex(BytesArray bytes, int partIndex) {}

private final InputStream inputStream;
private final int chunkSize;
private final AtomicLong totalBytesRead = new AtomicLong();
private final AtomicInteger currentPart;
private final int lastPartNumber;

HttpStreamChunker(URI uri, RequestRange range, int chunkSize) {
var inputStream = getHttpOrHttpsInputStream(uri, range);
this.inputStream = inputStream;
this.chunkSize = chunkSize;
this.lastPartNumber = range.startPart() + range.numParts();
this.currentPart = new AtomicInteger(range.startPart());
}

// This ctor exists for testing purposes only.
HttpStreamChunker(InputStream inputStream, RequestRange range, int chunkSize) {
this.inputStream = inputStream;
this.chunkSize = chunkSize;
this.lastPartNumber = range.startPart() + range.numParts();
this.currentPart = new AtomicInteger(range.startPart());
}

public boolean hasNext() {
return currentPart.get() < lastPartNumber;
}

public BytesAndPartIndex next() throws IOException {
int bytesRead = 0;
byte[] buf = new byte[chunkSize];

while (bytesRead < chunkSize) {
int read = inputStream.read(buf, bytesRead, chunkSize - bytesRead);
// EOF??
if (read == -1) {
break;
}
bytesRead += read;
}

if (bytesRead > 0) {
totalBytesRead.addAndGet(bytesRead);
return new BytesAndPartIndex(new BytesArray(buf, 0, bytesRead), currentPart.getAndIncrement());
} else {
return new BytesAndPartIndex(BytesArray.EMPTY, currentPart.get());
}
}

public long getTotalBytesRead() {
return totalBytesRead.get();
}

public int getCurrentPart() {
return currentPart.get();
}
}

static class InputStreamChunker {

private final InputStream inputStream;
Expand Down Expand Up @@ -101,21 +173,26 @@ public int getTotalBytesRead() {
}
}

static InputStream getInputStreamFromModelRepository(URI uri) throws IOException {
static InputStream getInputStreamFromModelRepository(URI uri) {
String scheme = uri.getScheme().toLowerCase(Locale.ROOT);

// if you add a scheme here, also add it to the bootstrap check in {@link MachineLearningPackageLoader#validateModelRepository}
switch (scheme) {
case "http":
case "https":
return getHttpOrHttpsInputStream(uri);
return getHttpOrHttpsInputStream(uri, null);
case "file":
return getFileInputStream(uri);
default:
throw new IllegalArgumentException("unsupported scheme");
}
}

/**
 * @return true if the URI's scheme is {@code file} (compared case-insensitively)
 */
static boolean uriIsFile(URI uri) {
    return "file".equals(uri.getScheme().toLowerCase(Locale.ROOT));
}

static VocabularyParts loadVocabulary(URI uri) {
if (uri.getPath().endsWith(".json")) {
try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) {
Expand Down Expand Up @@ -174,7 +251,7 @@ private ModelLoaderUtils() {}

@SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ")
@SuppressForbidden(reason = "we need socket connection to download")
private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException {
private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestRange range) {

assert uri.getUserInfo() == null : "URI's with credentials are not supported";

Expand All @@ -186,18 +263,30 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException
PrivilegedAction<InputStream> privilegedHttpReader = () -> {
try {
HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection();
if (range != null) {
conn.setRequestProperty("Range", range.bytesRange());
}
switch (conn.getResponseCode()) {
case HTTP_OK:
case HTTP_PARTIAL:
return conn.getInputStream();

case HTTP_MOVED_PERM:
case HTTP_MOVED_TEMP:
case HTTP_SEE_OTHER:
throw new IllegalStateException("redirects aren't supported yet");
case HTTP_NOT_FOUND:
throw new ResourceNotFoundException("{} not found", uri);
case 416: // Range not satisfiable, for some reason not in the list of constants
throw new IllegalStateException("Invalid request range [" + range.bytesRange() + "]");
default:
int responseCode = conn.getResponseCode();
throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), uri);
throw new ElasticsearchStatusException(
"error during downloading {}. Got response code {}",
RestStatus.fromCode(responseCode),
uri,
responseCode
);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
Expand All @@ -209,7 +298,7 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException

@SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ")
@SuppressForbidden(reason = "we need load model data from a file")
private static InputStream getFileInputStream(URI uri) {
static InputStream getFileInputStream(URI uri) {

SecurityManager sm = System.getSecurityManager();
if (sm != null) {
Expand All @@ -232,4 +321,53 @@ private static InputStream getFileInputStream(URI uri) {
return AccessController.doPrivileged(privilegedFileReader);
}

/**
 * Split a stream of size {@code sizeInBytes} into ranges aligned on
 * {@code chunkSizeBytes} boundaries. Each range contains a whole number of
 * chunks.
 * The first {@code numberOfStreams - 1} ranges split the bulk of the chunks
 * evenly (in terms of number of chunks, not byte size). The remainder is
 * further split so that the very last range covers only the single final
 * chunk and is no more than {@code chunkSizeBytes} in size. The separate
 * range for the final chunk is because when streaming and uploading a large
 * model definition, writing the last part has to be handled as a special case.
 * The result therefore contains {@code numberOfStreams + 1} ranges, except
 * when each stream holds a single chunk, in which case the penultimate range
 * is omitted and {@code numberOfStreams} ranges are returned.
 *
 * @param sizeInBytes The total size of the stream. Must be positive.
 * @param numberOfStreams Divide the bulk of the size into this many streams. Must be positive.
 * @param chunkSizeBytes The size of each chunk. Must be positive.
 * @return List of ranges covering the full {@code sizeInBytes}
 * @throws IllegalArgumentException if any argument is not positive
 */
static List<RequestRange> split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) {
    // Guard against degenerate inputs: numberOfStreams == 0 would otherwise
    // surface as a raw ArithmeticException (divide by zero) and non-positive
    // sizes would silently produce negative-length ranges.
    if (sizeInBytes <= 0) {
        throw new IllegalArgumentException("sizeInBytes must be positive, got [" + sizeInBytes + "]");
    }
    if (numberOfStreams <= 0) {
        throw new IllegalArgumentException("numberOfStreams must be positive, got [" + numberOfStreams + "]");
    }
    if (chunkSizeBytes <= 0) {
        throw new IllegalArgumentException("chunkSizeBytes must be positive, got [" + chunkSizeBytes + "]");
    }

    // Round up so a trailing partial chunk still gets its own slot
    int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes);

    var ranges = new ArrayList<RequestRange>();

    int baseChunksPerStream = numberOfChunks / numberOfStreams;
    int remainder = numberOfChunks % numberOfStreams;
    long startOffset = 0;
    int startChunkIndex = 0;

    // Distribute the remainder one extra chunk at a time over the first streams
    for (int i = 0; i < numberOfStreams - 1; i++) {
        int numChunksInStream = (i < remainder) ? baseChunksPerStream + 1 : baseChunksPerStream;
        long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes) - 1; // range index is 0 based
        ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInStream));
        startOffset = rangeEnd + 1; // range is inclusive start and end
        startChunkIndex += numChunksInStream;
    }

    // Want the final range request to be a single chunk
    if (baseChunksPerStream > 1) {
        int numChunksExcludingFinal = baseChunksPerStream - 1;
        long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes) - 1;
        ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksExcludingFinal));

        startOffset = rangeEnd + 1;
        startChunkIndex += numChunksExcludingFinal;
    }

    // The final range is a single chunk the end of which should not exceed sizeInBytes
    long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerStream * chunkSizeBytes)) - 1;
    ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, 1));

    return ranges;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ protected void masterOperation(Task task, Request request, ClusterState state, A
String packagedModelId = request.getPackagedModelId();
logger.debug(() -> format("Fetch package manifest for [%s] from [%s]", packagedModelId, repository));

threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME).execute(() -> {
threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME).execute(() -> {
try {
URI uri = ModelLoaderUtils.resolvePackageLocation(repository, packagedModelId + ModelLoaderUtils.METADATA_FILE_EXTENSION);
InputStream inputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri);
Expand Down
Loading
Loading