Skip to content

Commit e7e0dff

Browse files
authored
Fix: ml/engine/utils/FileUtils casts long file length to int incorrectly (opensearch-project#3198)
* Use longs when splitting model zip file Signed-off-by: Max Lepikhin <[email protected]> * add test Signed-off-by: Max Lepikhin <[email protected]> * spotless Signed-off-by: Max Lepikhin <[email protected]> * clean up test Signed-off-by: Max Lepikhin <[email protected]> --------- Signed-off-by: Max Lepikhin <[email protected]>
1 parent 5a987bb commit e7e0dff

File tree

2 files changed

+67
-4
lines changed

2 files changed

+67
-4
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@ public class FileUtils {
4545
* @throws IOException
4646
*/
4747
public static List<String> splitFileIntoChunks(File file, Path outputPath, int chunkSize) throws IOException {
48-
int fileSize = (int) file.length();
48+
long fileSize = file.length();
4949
ArrayList<String> nameList = new ArrayList<>();
5050
try (InputStream inStream = new BufferedInputStream(new FileInputStream(file))) {
5151
int numberOfChunk = 0;
52-
int totalBytesRead = 0;
52+
long totalBytesRead = 0;
5353
while (totalBytesRead < fileSize) {
5454
String partName = numberOfChunk + "";
55-
int bytesRemaining = fileSize - totalBytesRead;
55+
long bytesRemaining = fileSize - totalBytesRead;
5656
if (bytesRemaining < chunkSize) {
57-
chunkSize = bytesRemaining;
57+
chunkSize = (int) bytesRemaining;
5858
}
5959
byte[] temporary = new byte[chunkSize];
6060
int bytesRead = inStream.read(temporary, 0, chunkSize);
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.utils;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertTrue;
10+
11+
import java.io.File;
12+
import java.nio.file.Files;
13+
import java.nio.file.Path;
14+
import java.util.Arrays;
15+
import java.util.List;
16+
import java.util.Random;
17+
18+
import org.junit.After;
19+
import org.junit.Assert;
20+
import org.junit.Before;
21+
import org.junit.Test;
22+
import org.junit.rules.TemporaryFolder;
23+
24+
public class FileUtilsTest {
25+
private TemporaryFolder tempDir;
26+
27+
@Before
28+
public void setUp() throws Exception {
29+
tempDir = new TemporaryFolder();
30+
tempDir.create();
31+
}
32+
33+
@After
34+
public void tearDown() {
35+
if (tempDir != null) {
36+
tempDir.delete();
37+
}
38+
}
39+
40+
@Test
41+
public void testSplitFileIntoChunks() throws Exception {
42+
// Write file.
43+
Random random = new Random();
44+
File file = tempDir.newFile("large_file");
45+
byte[] data = new byte[1017];
46+
random.nextBytes(data);
47+
Files.write(file.toPath(), data);
48+
49+
// Split file into chunks.
50+
int chunkSize = 325;
51+
List<String> chunkPaths = FileUtils.splitFileIntoChunks(file, tempDir.newFolder().toPath(), chunkSize);
52+
53+
// Verify.
54+
int currentPosition = 0;
55+
for (String chunkPath : chunkPaths) {
56+
byte[] chunk = Files.readAllBytes(Path.of(chunkPath));
57+
assertTrue("Chunk size", currentPosition + chunk.length <= data.length);
58+
Assert.assertArrayEquals(Arrays.copyOfRange(data, currentPosition, currentPosition + chunk.length), chunk);
59+
currentPosition += chunk.length;
60+
}
61+
assertEquals(currentPosition, data.length);
62+
}
63+
}

0 commit comments

Comments
 (0)