Skip to content

Commit 4fe2851

Browse files
authored
[ML] Download and write model parts using multiple streams (elastic#111684) (elastic#112859)
Uses the range header to split the model download into multiple streams using a separate thread for each stream
1 parent 53ff0ac commit 4fe2851

File tree

11 files changed

+847
-166
lines changed

11 files changed

+847
-166
lines changed

docs/changelog/111684.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 111684
2+
summary: Write downloaded model parts async
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515
import org.elasticsearch.common.Strings;
1616
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1717
import org.elasticsearch.common.settings.Setting;
18+
import org.elasticsearch.common.settings.Settings;
19+
import org.elasticsearch.common.util.concurrent.EsExecutors;
1820
import org.elasticsearch.plugins.ActionPlugin;
1921
import org.elasticsearch.plugins.Plugin;
2022
import org.elasticsearch.tasks.Task;
23+
import org.elasticsearch.threadpool.ExecutorBuilder;
24+
import org.elasticsearch.threadpool.FixedExecutorBuilder;
2125
import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction;
2226
import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction;
2327
import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask;
28+
import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter;
2429
import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction;
2530
import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage;
2631

@@ -44,16 +49,15 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin
4449
Setting.Property.Dynamic
4550
);
4651

47-
// re-using thread pool setup by the ml plugin
48-
public static final String UTILITY_THREAD_POOL_NAME = "ml_utility";
49-
5052
// This link will be invalid for serverless, but serverless will never be
5153
// air-gapped, so this message should never be needed.
5254
private static final String MODEL_REPOSITORY_DOCUMENTATION_LINK = format(
5355
"https://www.elastic.co/guide/en/machine-learning/%s/ml-nlp-elser.html#air-gapped-install",
5456
Build.current().version().replaceFirst("^(\\d+\\.\\d+).*", "$1")
5557
);
5658

59+
public static final String MODEL_DOWNLOAD_THREADPOOL_NAME = "model_download";
60+
5761
public MachineLearningPackageLoader() {}
5862

5963
@Override
@@ -81,6 +85,24 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
8185
);
8286
}
8387

88+
@Override
89+
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
90+
return List.of(modelDownloadExecutor(settings));
91+
}
92+
93+
public static FixedExecutorBuilder modelDownloadExecutor(Settings settings) {
94+
// Threadpool with a fixed number of threads for
95+
// downloading the model definition files
96+
return new FixedExecutorBuilder(
97+
settings,
98+
MODEL_DOWNLOAD_THREADPOOL_NAME,
99+
ModelImporter.NUMBER_OF_STREAMS,
100+
-1, // unbounded queue size
101+
"xpack.ml.model_download_thread_pool",
102+
EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
103+
);
104+
}
105+
84106
@Override
85107
public List<BootstrapCheck> getBootstrapChecks() {
86108
return List.of(new BootstrapCheck() {

x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java

Lines changed: 239 additions & 74 deletions
Large diffs are not rendered by default.

x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java

Lines changed: 143 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.common.io.Streams;
1818
import org.elasticsearch.common.unit.ByteSizeUnit;
1919
import org.elasticsearch.common.unit.ByteSizeValue;
20+
import org.elasticsearch.core.Nullable;
2021
import org.elasticsearch.core.SuppressForbidden;
2122
import org.elasticsearch.rest.RestStatus;
2223
import org.elasticsearch.xcontent.XContentParser;
@@ -34,16 +35,20 @@
3435
import java.security.AccessController;
3536
import java.security.MessageDigest;
3637
import java.security.PrivilegedAction;
38+
import java.util.ArrayList;
3739
import java.util.HashMap;
3840
import java.util.List;
3941
import java.util.Locale;
4042
import java.util.Map;
43+
import java.util.concurrent.atomic.AtomicInteger;
44+
import java.util.concurrent.atomic.AtomicLong;
4145
import java.util.stream.Collectors;
4246

4347
import static java.net.HttpURLConnection.HTTP_MOVED_PERM;
4448
import static java.net.HttpURLConnection.HTTP_MOVED_TEMP;
4549
import static java.net.HttpURLConnection.HTTP_NOT_FOUND;
4650
import static java.net.HttpURLConnection.HTTP_OK;
51+
import static java.net.HttpURLConnection.HTTP_PARTIAL;
4752
import static java.net.HttpURLConnection.HTTP_SEE_OTHER;
4853

4954
/**
@@ -61,6 +66,73 @@ final class ModelLoaderUtils {
6166

6267
record VocabularyParts(List<String> vocab, List<String> merges, List<Double> scores) {}
6368

69+
/**
 * A byte range of the model file, plus the numbering of the parts (chunks)
 * it carries: {@code startPart} is the index of the first part in the range
 * and {@code numParts} the number of parts it contains.
 * Both {@code rangeStart} and {@code rangeEnd} are byte offsets.
 */
record RequestRange(long rangeStart, long rangeEnd, int startPart, int numParts) {

    /**
     * @return this range formatted as the value of an HTTP {@code Range}
     *         header, i.e. {@code bytes=<rangeStart>-<rangeEnd>}
     */
    public String bytesRange() {
        return new StringBuilder("bytes=").append(rangeStart).append('-').append(rangeEnd).toString();
    }
}
75+
76+
static class HttpStreamChunker {
77+
78+
record BytesAndPartIndex(BytesArray bytes, int partIndex) {}
79+
80+
private final InputStream inputStream;
81+
private final int chunkSize;
82+
private final AtomicLong totalBytesRead = new AtomicLong();
83+
private final AtomicInteger currentPart;
84+
private final int lastPartNumber;
85+
86+
HttpStreamChunker(URI uri, RequestRange range, int chunkSize) {
87+
var inputStream = getHttpOrHttpsInputStream(uri, range);
88+
this.inputStream = inputStream;
89+
this.chunkSize = chunkSize;
90+
this.lastPartNumber = range.startPart() + range.numParts();
91+
this.currentPart = new AtomicInteger(range.startPart());
92+
}
93+
94+
// This ctor exists for testing purposes only.
95+
HttpStreamChunker(InputStream inputStream, RequestRange range, int chunkSize) {
96+
this.inputStream = inputStream;
97+
this.chunkSize = chunkSize;
98+
this.lastPartNumber = range.startPart() + range.numParts();
99+
this.currentPart = new AtomicInteger(range.startPart());
100+
}
101+
102+
public boolean hasNext() {
103+
return currentPart.get() < lastPartNumber;
104+
}
105+
106+
public BytesAndPartIndex next() throws IOException {
107+
int bytesRead = 0;
108+
byte[] buf = new byte[chunkSize];
109+
110+
while (bytesRead < chunkSize) {
111+
int read = inputStream.read(buf, bytesRead, chunkSize - bytesRead);
112+
// EOF??
113+
if (read == -1) {
114+
break;
115+
}
116+
bytesRead += read;
117+
}
118+
119+
if (bytesRead > 0) {
120+
totalBytesRead.addAndGet(bytesRead);
121+
return new BytesAndPartIndex(new BytesArray(buf, 0, bytesRead), currentPart.getAndIncrement());
122+
} else {
123+
return new BytesAndPartIndex(BytesArray.EMPTY, currentPart.get());
124+
}
125+
}
126+
127+
public long getTotalBytesRead() {
128+
return totalBytesRead.get();
129+
}
130+
131+
public int getCurrentPart() {
132+
return currentPart.get();
133+
}
134+
}
135+
64136
static class InputStreamChunker {
65137

66138
private final InputStream inputStream;
@@ -101,21 +173,26 @@ public int getTotalBytesRead() {
101173
}
102174
}
103175

104-
static InputStream getInputStreamFromModelRepository(URI uri) throws IOException {
176+
static InputStream getInputStreamFromModelRepository(URI uri) {
105177
String scheme = uri.getScheme().toLowerCase(Locale.ROOT);
106178

107179
// if you add a scheme here, also add it to the bootstrap check in {@link MachineLearningPackageLoader#validateModelRepository}
108180
switch (scheme) {
109181
case "http":
110182
case "https":
111-
return getHttpOrHttpsInputStream(uri);
183+
return getHttpOrHttpsInputStream(uri, null);
112184
case "file":
113185
return getFileInputStream(uri);
114186
default:
115187
throw new IllegalArgumentException("unsupported scheme");
116188
}
117189
}
118190

191+
static boolean uriIsFile(URI uri) {
192+
String scheme = uri.getScheme().toLowerCase(Locale.ROOT);
193+
return "file".equals(scheme);
194+
}
195+
119196
static VocabularyParts loadVocabulary(URI uri) {
120197
if (uri.getPath().endsWith(".json")) {
121198
try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) {
@@ -174,7 +251,7 @@ private ModelLoaderUtils() {}
174251

175252
@SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ")
176253
@SuppressForbidden(reason = "we need socket connection to download")
177-
private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException {
254+
private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestRange range) {
178255

179256
assert uri.getUserInfo() == null : "URI's with credentials are not supported";
180257

@@ -186,18 +263,30 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException
186263
PrivilegedAction<InputStream> privilegedHttpReader = () -> {
187264
try {
188265
HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection();
266+
if (range != null) {
267+
conn.setRequestProperty("Range", range.bytesRange());
268+
}
189269
switch (conn.getResponseCode()) {
190270
case HTTP_OK:
271+
case HTTP_PARTIAL:
191272
return conn.getInputStream();
273+
192274
case HTTP_MOVED_PERM:
193275
case HTTP_MOVED_TEMP:
194276
case HTTP_SEE_OTHER:
195277
throw new IllegalStateException("redirects aren't supported yet");
196278
case HTTP_NOT_FOUND:
197279
throw new ResourceNotFoundException("{} not found", uri);
280+
case 416: // Range not satisfiable, for some reason not in the list of constants
281+
throw new IllegalStateException("Invalid request range [" + range.bytesRange() + "]");
198282
default:
199283
int responseCode = conn.getResponseCode();
200-
throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), uri);
284+
throw new ElasticsearchStatusException(
285+
"error during downloading {}. Got response code {}",
286+
RestStatus.fromCode(responseCode),
287+
uri,
288+
responseCode
289+
);
201290
}
202291
} catch (IOException e) {
203292
throw new UncheckedIOException(e);
@@ -209,7 +298,7 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException
209298

210299
@SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ")
211300
@SuppressForbidden(reason = "we need load model data from a file")
212-
private static InputStream getFileInputStream(URI uri) {
301+
static InputStream getFileInputStream(URI uri) {
213302

214303
SecurityManager sm = System.getSecurityManager();
215304
if (sm != null) {
@@ -232,4 +321,53 @@ private static InputStream getFileInputStream(URI uri) {
232321
return AccessController.doPrivileged(privilegedFileReader);
233322
}
234323

324+
/**
325+
* Split a stream of size {@code sizeInBytes} into {@code numberOfStreams} +1
326+
* ranges aligned on {@code chunkSizeBytes} boundaries. Each range contains a
327+
* whole number of chunks.
328+
* The first {@code numberOfStreams} ranges will be split evenly (in terms of
329+
* number of chunks not the byte size), the final range split
330+
* is for the single final chunk and will be no more than {@code chunkSizeBytes}
331+
* in size. The separate range for the final chunk is because when streaming and
332+
* uploading a large model definition, writing the last part has to handled
333+
* as a special case.
334+
* @param sizeInBytes The total size of the stream
335+
* @param numberOfStreams Divide the bulk of the size into this many streams.
336+
* @param chunkSizeBytes The size of each chunk
337+
* @return List of {@code numberOfStreams} + 1 ranges.
338+
*/
339+
static List<RequestRange> split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) {
340+
int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes);
341+
342+
var ranges = new ArrayList<RequestRange>();
343+
344+
int baseChunksPerStream = numberOfChunks / numberOfStreams;
345+
int remainder = numberOfChunks % numberOfStreams;
346+
long startOffset = 0;
347+
int startChunkIndex = 0;
348+
349+
for (int i = 0; i < numberOfStreams - 1; i++) {
350+
int numChunksInStream = (i < remainder) ? baseChunksPerStream + 1 : baseChunksPerStream;
351+
long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes) - 1; // range index is 0 based
352+
ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInStream));
353+
startOffset = rangeEnd + 1; // range is inclusive start and end
354+
startChunkIndex += numChunksInStream;
355+
}
356+
357+
// Want the final range request to be a single chunk
358+
if (baseChunksPerStream > 1) {
359+
int numChunksExcludingFinal = baseChunksPerStream - 1;
360+
long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes) - 1;
361+
ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksExcludingFinal));
362+
363+
startOffset = rangeEnd + 1;
364+
startChunkIndex += numChunksExcludingFinal;
365+
}
366+
367+
// The final range is a single chunk the end of which should not exceed sizeInBytes
368+
long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerStream * chunkSizeBytes)) - 1;
369+
ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, 1));
370+
371+
return ranges;
372+
}
235373
}

x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ protected void masterOperation(Task task, Request request, ClusterState state, A
7777
String packagedModelId = request.getPackagedModelId();
7878
logger.debug(() -> format("Fetch package manifest for [%s] from [%s]", packagedModelId, repository));
7979

80-
threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME).execute(() -> {
80+
threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME).execute(() -> {
8181
try {
8282
URI uri = ModelLoaderUtils.resolvePackageLocation(repository, packagedModelId + ModelLoaderUtils.METADATA_FILE_EXTENSION);
8383
InputStream inputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri);

0 commit comments

Comments
 (0)