17
17
import org .elasticsearch .common .io .Streams ;
18
18
import org .elasticsearch .common .unit .ByteSizeUnit ;
19
19
import org .elasticsearch .common .unit .ByteSizeValue ;
20
+ import org .elasticsearch .core .Nullable ;
20
21
import org .elasticsearch .core .SuppressForbidden ;
21
22
import org .elasticsearch .rest .RestStatus ;
22
23
import org .elasticsearch .xcontent .XContentParser ;
34
35
import java .security .AccessController ;
35
36
import java .security .MessageDigest ;
36
37
import java .security .PrivilegedAction ;
38
+ import java .util .ArrayList ;
37
39
import java .util .HashMap ;
38
40
import java .util .List ;
39
41
import java .util .Locale ;
40
42
import java .util .Map ;
43
+ import java .util .concurrent .atomic .AtomicInteger ;
44
+ import java .util .concurrent .atomic .AtomicLong ;
41
45
import java .util .stream .Collectors ;
42
46
43
47
import static java .net .HttpURLConnection .HTTP_MOVED_PERM ;
44
48
import static java .net .HttpURLConnection .HTTP_MOVED_TEMP ;
45
49
import static java .net .HttpURLConnection .HTTP_NOT_FOUND ;
46
50
import static java .net .HttpURLConnection .HTTP_OK ;
51
+ import static java .net .HttpURLConnection .HTTP_PARTIAL ;
47
52
import static java .net .HttpURLConnection .HTTP_SEE_OTHER ;
48
53
49
54
/**
@@ -61,6 +66,73 @@ final class ModelLoaderUtils {
61
66
62
67
/**
 * Carrier for the three lists that {@code loadVocabulary} returns together.
 *
 * NOTE(review): judging by the component names these are presumably the
 * vocabulary entries, the merge rules and the per-entry scores - confirm
 * against the vocabulary parsing code.
 */
record VocabularyParts (List <String > vocab , List <String > merges , List <Double > scores ) {}
63
68
69
/**
 * A byte range of a stream (start and end offsets are both inclusive)
 * together with the parts it covers.
 *
 * @param rangeStart offset of the first byte of the range
 * @param rangeEnd   offset of the last byte of the range
 * @param startPart  index of the first part contained in the range
 * @param numParts   number of parts contained in the range
 */
record RequestRange(long rangeStart, long rangeEnd, int startPart, int numParts) {

    /**
     * @return the range formatted as the value of an HTTP {@code Range}
     *         header, for example {@code bytes=0-1023}
     */
    public String bytesRange() {
        return new StringBuilder("bytes=").append(rangeStart).append('-').append(rangeEnd).toString();
    }
}
75
+
76
+ static class HttpStreamChunker {
77
+
78
+ record BytesAndPartIndex (BytesArray bytes , int partIndex ) {}
79
+
80
+ private final InputStream inputStream ;
81
+ private final int chunkSize ;
82
+ private final AtomicLong totalBytesRead = new AtomicLong ();
83
+ private final AtomicInteger currentPart ;
84
+ private final int lastPartNumber ;
85
+
86
+ HttpStreamChunker (URI uri , RequestRange range , int chunkSize ) {
87
+ var inputStream = getHttpOrHttpsInputStream (uri , range );
88
+ this .inputStream = inputStream ;
89
+ this .chunkSize = chunkSize ;
90
+ this .lastPartNumber = range .startPart () + range .numParts ();
91
+ this .currentPart = new AtomicInteger (range .startPart ());
92
+ }
93
+
94
+ // This ctor exists for testing purposes only.
95
+ HttpStreamChunker (InputStream inputStream , RequestRange range , int chunkSize ) {
96
+ this .inputStream = inputStream ;
97
+ this .chunkSize = chunkSize ;
98
+ this .lastPartNumber = range .startPart () + range .numParts ();
99
+ this .currentPart = new AtomicInteger (range .startPart ());
100
+ }
101
+
102
+ public boolean hasNext () {
103
+ return currentPart .get () < lastPartNumber ;
104
+ }
105
+
106
+ public BytesAndPartIndex next () throws IOException {
107
+ int bytesRead = 0 ;
108
+ byte [] buf = new byte [chunkSize ];
109
+
110
+ while (bytesRead < chunkSize ) {
111
+ int read = inputStream .read (buf , bytesRead , chunkSize - bytesRead );
112
+ // EOF??
113
+ if (read == -1 ) {
114
+ break ;
115
+ }
116
+ bytesRead += read ;
117
+ }
118
+
119
+ if (bytesRead > 0 ) {
120
+ totalBytesRead .addAndGet (bytesRead );
121
+ return new BytesAndPartIndex (new BytesArray (buf , 0 , bytesRead ), currentPart .getAndIncrement ());
122
+ } else {
123
+ return new BytesAndPartIndex (BytesArray .EMPTY , currentPart .get ());
124
+ }
125
+ }
126
+
127
+ public long getTotalBytesRead () {
128
+ return totalBytesRead .get ();
129
+ }
130
+
131
+ public int getCurrentPart () {
132
+ return currentPart .get ();
133
+ }
134
+ }
135
+
64
136
static class InputStreamChunker {
65
137
66
138
private final InputStream inputStream ;
@@ -101,21 +173,26 @@ public int getTotalBytesRead() {
101
173
}
102
174
}
103
175
104
- static InputStream getInputStreamFromModelRepository (URI uri ) throws IOException {
176
+ static InputStream getInputStreamFromModelRepository (URI uri ) {
105
177
String scheme = uri .getScheme ().toLowerCase (Locale .ROOT );
106
178
107
179
// if you add a scheme here, also add it to the bootstrap check in {@link MachineLearningPackageLoader#validateModelRepository}
108
180
switch (scheme ) {
109
181
case "http" :
110
182
case "https" :
111
- return getHttpOrHttpsInputStream (uri );
183
+ return getHttpOrHttpsInputStream (uri , null );
112
184
case "file" :
113
185
return getFileInputStream (uri );
114
186
default :
115
187
throw new IllegalArgumentException ("unsupported scheme" );
116
188
}
117
189
}
118
190
191
+ static boolean uriIsFile (URI uri ) {
192
+ String scheme = uri .getScheme ().toLowerCase (Locale .ROOT );
193
+ return "file" .equals (scheme );
194
+ }
195
+
119
196
static VocabularyParts loadVocabulary (URI uri ) {
120
197
if (uri .getPath ().endsWith (".json" )) {
121
198
try (InputStream vocabInputStream = getInputStreamFromModelRepository (uri )) {
@@ -174,7 +251,7 @@ private ModelLoaderUtils() {}
174
251
175
252
@ SuppressWarnings ("'java.lang.SecurityManager' is deprecated and marked for removal " )
176
253
@ SuppressForbidden (reason = "we need socket connection to download" )
177
- private static InputStream getHttpOrHttpsInputStream (URI uri ) throws IOException {
254
+ private static InputStream getHttpOrHttpsInputStream (URI uri , @ Nullable RequestRange range ) {
178
255
179
256
assert uri .getUserInfo () == null : "URI's with credentials are not supported" ;
180
257
@@ -186,18 +263,30 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException
186
263
PrivilegedAction <InputStream > privilegedHttpReader = () -> {
187
264
try {
188
265
HttpURLConnection conn = (HttpURLConnection ) uri .toURL ().openConnection ();
266
+ if (range != null ) {
267
+ conn .setRequestProperty ("Range" , range .bytesRange ());
268
+ }
189
269
switch (conn .getResponseCode ()) {
190
270
case HTTP_OK :
271
+ case HTTP_PARTIAL :
191
272
return conn .getInputStream ();
273
+
192
274
case HTTP_MOVED_PERM :
193
275
case HTTP_MOVED_TEMP :
194
276
case HTTP_SEE_OTHER :
195
277
throw new IllegalStateException ("redirects aren't supported yet" );
196
278
case HTTP_NOT_FOUND :
197
279
throw new ResourceNotFoundException ("{} not found" , uri );
280
+ case 416 : // Range not satisfiable, for some reason not in the list of constants
281
+ throw new IllegalStateException ("Invalid request range [" + range .bytesRange () + "]" );
198
282
default :
199
283
int responseCode = conn .getResponseCode ();
200
- throw new ElasticsearchStatusException ("error during downloading {}" , RestStatus .fromCode (responseCode ), uri );
284
+ throw new ElasticsearchStatusException (
285
+ "error during downloading {}. Got response code {}" ,
286
+ RestStatus .fromCode (responseCode ),
287
+ uri ,
288
+ responseCode
289
+ );
201
290
}
202
291
} catch (IOException e ) {
203
292
throw new UncheckedIOException (e );
@@ -209,7 +298,7 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException
209
298
210
299
@ SuppressWarnings ("'java.lang.SecurityManager' is deprecated and marked for removal " )
211
300
@ SuppressForbidden (reason = "we need load model data from a file" )
212
- private static InputStream getFileInputStream (URI uri ) {
301
+ static InputStream getFileInputStream (URI uri ) {
213
302
214
303
SecurityManager sm = System .getSecurityManager ();
215
304
if (sm != null ) {
@@ -232,4 +321,53 @@ private static InputStream getFileInputStream(URI uri) {
232
321
return AccessController .doPrivileged (privilegedFileReader );
233
322
}
234
323
324
+ /**
325
+ * Split a stream of size {@code sizeInBytes} into {@code numberOfStreams} +1
326
+ * ranges aligned on {@code chunkSizeBytes} boundaries. Each range contains a
327
+ * whole number of chunks.
328
+ * The first {@code numberOfStreams} ranges will be split evenly (in terms of
329
+ * number of chunks not the byte size), the final range split
330
+ * is for the single final chunk and will be no more than {@code chunkSizeBytes}
331
+ * in size. The separate range for the final chunk is because when streaming and
332
+ * uploading a large model definition, writing the last part has to handled
333
+ * as a special case.
334
+ * @param sizeInBytes The total size of the stream
335
+ * @param numberOfStreams Divide the bulk of the size into this many streams.
336
+ * @param chunkSizeBytes The size of each chunk
337
+ * @return List of {@code numberOfStreams} + 1 ranges.
338
+ */
339
+ static List <RequestRange > split (long sizeInBytes , int numberOfStreams , long chunkSizeBytes ) {
340
+ int numberOfChunks = (int ) ((sizeInBytes + chunkSizeBytes - 1 ) / chunkSizeBytes );
341
+
342
+ var ranges = new ArrayList <RequestRange >();
343
+
344
+ int baseChunksPerStream = numberOfChunks / numberOfStreams ;
345
+ int remainder = numberOfChunks % numberOfStreams ;
346
+ long startOffset = 0 ;
347
+ int startChunkIndex = 0 ;
348
+
349
+ for (int i = 0 ; i < numberOfStreams - 1 ; i ++) {
350
+ int numChunksInStream = (i < remainder ) ? baseChunksPerStream + 1 : baseChunksPerStream ;
351
+ long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes ) - 1 ; // range index is 0 based
352
+ ranges .add (new RequestRange (startOffset , rangeEnd , startChunkIndex , numChunksInStream ));
353
+ startOffset = rangeEnd + 1 ; // range is inclusive start and end
354
+ startChunkIndex += numChunksInStream ;
355
+ }
356
+
357
+ // Want the final range request to be a single chunk
358
+ if (baseChunksPerStream > 1 ) {
359
+ int numChunksExcludingFinal = baseChunksPerStream - 1 ;
360
+ long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes ) - 1 ;
361
+ ranges .add (new RequestRange (startOffset , rangeEnd , startChunkIndex , numChunksExcludingFinal ));
362
+
363
+ startOffset = rangeEnd + 1 ;
364
+ startChunkIndex += numChunksExcludingFinal ;
365
+ }
366
+
367
+ // The final range is a single chunk the end of which should not exceed sizeInBytes
368
+ long rangeEnd = Math .min (sizeInBytes , startOffset + (baseChunksPerStream * chunkSizeBytes )) - 1 ;
369
+ ranges .add (new RequestRange (startOffset , rangeEnd , startChunkIndex , 1 ));
370
+
371
+ return ranges ;
372
+ }
235
373
}
0 commit comments