Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/126009.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 126009
summary: Change ModelLoaderUtils.split to return the correct number of chunks and ranges.
area: Machine Learning
type: bug
issues:
- 121799
Original file line number Diff line number Diff line change
Expand Up @@ -336,50 +336,44 @@ static InputStream getFileInputStream(URI uri) {
* Split a stream of size {@code sizeInBytes} into {@code numberOfStreams} +1
* ranges aligned on {@code chunkSizeBytes} boundaries. Each range contains a
* whole number of chunks.
* The first {@code numberOfStreams} ranges will be split evenly (in terms of
* number of chunks not the byte size), the final range split
* All ranges except the final range will be split approximately evenly
* (in terms of number of chunks not the byte size), the final range split
* is for the single final chunk and will be no more than {@code chunkSizeBytes}
* in size. The separate range for the final chunk is because when streaming and
* uploading a large model definition, writing the last part has to handled
* as a special case.
* Less ranges may be returned in case the stream size is too small.
* Fewer ranges may be returned in case the stream size is too small.
* @param sizeInBytes The total size of the stream
* @param numberOfStreams Divide the bulk of the size into this many streams.
* @param chunkSizeBytes The size of each chunk
* @return List of {@code numberOfStreams} + 1 ranges.
* @return List of {@code numberOfStreams} + 1 or fewer ranges.
*/
static List<RequestRange> split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) {
int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes);

int numberOfRanges = numberOfStreams + 1;
if (numberOfStreams > numberOfChunks) {
numberOfStreams = numberOfChunks;
numberOfRanges = numberOfChunks;
}
var ranges = new ArrayList<RequestRange>();

int baseChunksPerStream = numberOfChunks / numberOfStreams;
int remainder = numberOfChunks % numberOfStreams;
int baseChunksPerRange = (numberOfChunks - 1) / (numberOfRanges - 1);
int remainder = (numberOfChunks - 1) % (numberOfRanges - 1);
long startOffset = 0;
int startChunkIndex = 0;

for (int i = 0; i < numberOfStreams - 1; i++) {
int numChunksInStream = (i < remainder) ? baseChunksPerStream + 1 : baseChunksPerStream;
long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes) - 1; // range index is 0 based
ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInStream));
startOffset = rangeEnd + 1; // range is inclusive start and end
startChunkIndex += numChunksInStream;
}
for (int i = 0; i < numberOfRanges - 1; i++) {
int numChunksInRange = (i < remainder) ? baseChunksPerRange + 1 : baseChunksPerRange;

// Want the final range request to be a single chunk
if (baseChunksPerStream > 1) {
int numChunksExcludingFinal = baseChunksPerStream - 1;
long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes) - 1;
ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksExcludingFinal));
long rangeEnd = startOffset + (((long) numChunksInRange) * chunkSizeBytes) - 1; // range index is 0 based

startOffset = rangeEnd + 1;
startChunkIndex += numChunksExcludingFinal;
ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInRange));
startOffset = rangeEnd + 1; // range is inclusive start and end
startChunkIndex += numChunksInRange;
}

// The final range is a single chunk the end of which should not exceed sizeInBytes
long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerStream * chunkSizeBytes)) - 1;
long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerRange * chunkSizeBytes)) - 1;
ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, 1));

return ranges;
Expand Down
Loading