Skip to content

Commit a1e9644

Browse files
committed
fix for edge condition with connections smaller than total size
1 parent 2bf0c62 commit a1e9644

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

src/utils/net/download_file.rs

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,10 @@ async fn download_http_parallel(
282282
.await
283283
.context("couldn't allocate file space")?;
284284

285-
let chunk_size = total_size / num_connections as u64;
285+
// Prevent underflow when file is smaller than connection count
286+
// Use at most as many connections as there are bytes
287+
let effective_connections = (num_connections as u64).min(total_size.max(1));
288+
let chunk_size = total_size / effective_connections;
286289

287290
// Progress tracking - log every 5 seconds like the forest::progress system
288291
let bytes_downloaded = Arc::new(std::sync::atomic::AtomicU64::new(0));
@@ -292,7 +295,7 @@ async fn download_http_parallel(
292295
const UPDATE_FREQUENCY: Duration = Duration::from_secs(5);
293296

294297
// Download chunks in parallel
295-
let download_tasks = (0..num_connections).map(|i| {
298+
let download_tasks = (0..effective_connections).map(|i| {
296299
let client = client.clone();
297300
let url = url.clone();
298301
let tmp_path = tmp_dst_path.clone();
@@ -301,11 +304,11 @@ async fn download_http_parallel(
301304
let last_logged_time = Arc::clone(&last_logged_time);
302305
let callback = callback.clone();
303306

304-
let start = i as u64 * chunk_size;
305-
let end = if i == num_connections - 1 {
307+
let start = i * chunk_size;
308+
let end = if i == effective_connections - 1 {
306309
total_size - 1
307310
} else {
308-
(i + 1) as u64 * chunk_size - 1
311+
((i + 1) * chunk_size - 1).min(total_size - 1)
309312
};
310313

311314
async move {
@@ -405,7 +408,7 @@ async fn download_http_parallel(
405408

406409
// Execute all downloads in parallel and collect results
407410
let results: Vec<_> = stream::iter(download_tasks)
408-
.buffer_unordered(num_connections)
411+
.buffer_unordered(effective_connections as usize)
409412
.collect()
410413
.await;
411414

@@ -885,4 +888,25 @@ mod test {
885888
let content = std::fs::read(&result).unwrap();
886889
assert_eq!(content, TEST_FILE_CONTENT);
887890
}
891+
892+
#[tokio::test]
893+
async fn test_small_file_with_many_connections() {
894+
// Test edge case: file smaller than connection count
895+
// This tests the underflow prevention when chunk_size would be 0
896+
let small_content: &[u8] = b"Hi!"; // 3 bytes
897+
let server = TestServer::start_with_content(small_content).await;
898+
let temp_dir = tempfile::tempdir().unwrap();
899+
let url = server.url("/test-file");
900+
901+
// Try to download with more connections than bytes
902+
let result = download_http_parallel(&url, temp_dir.path(), "tiny.dat", 5, None)
903+
.await
904+
.unwrap();
905+
906+
assert!(result.exists());
907+
908+
// Verify content is correct
909+
let downloaded = std::fs::read(&result).unwrap();
910+
assert_eq!(downloaded, small_content);
911+
}
888912
}

0 commit comments

Comments
 (0)