Skip to content

Commit 338d046

Browse files
committed
fix: refactor image download
1 parent a0cab1a commit 338d046

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

xtask/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ clap = { version = "4.4", features = ["derive"] }
1111
colored = "3"
1212
ostool = "0.8"
1313
jkconfig = "0.1"
14-
reqwest = "0.11"
14+
reqwest = "0.12"
1515
schemars = { version = "1", features = ["derive"] }
1616
serde = { version = "1.0", features = ["derive"] }
1717
serde_json = "1"

xtask/src/image.rs

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use std::fs;
2828
use std::env;
2929
use std::io::Read;
3030
// use tokio::fs::File; // Unused import
31-
use tokio::io::AsyncWriteExt;
31+
use tokio::io::{AsyncWriteExt, BufWriter};
3232

3333
/// Base URL for downloading images
3434
const IMAGE_URL_BASE: &str = "https://github.com/arceos-hypervisor/axvisor-guest/releases/download/v0.0.18/";
@@ -345,21 +345,46 @@ async fn image_download(image_name: &str, output_dir: Option<String>, extract: b
345345
println!("Starting download...");
346346

347347
// Use reqwest to download the file
348-
let response = reqwest::get(&download_url).await?;
348+
let mut response = reqwest::get(&download_url).await?;
349349
if !response.status().is_success() {
350350
return Err(anyhow!("Failed to download file: HTTP {}", response.status()));
351351
}
352352

353-
let bytes = response.bytes().await?;
354-
355-
// Write all bytes at once, ensuring we overwrite any existing file
356-
let mut file = tokio::fs::OpenOptions::new()
353+
// Create file with buffered writer for efficient streaming
354+
let file = tokio::fs::OpenOptions::new()
357355
.write(true)
358356
.create(true)
359357
.truncate(true)
360358
.open(&output_path)
361359
.await?;
362-
file.write_all(&bytes).await?;
360+
let mut writer = BufWriter::new(file);
361+
362+
// Get content length for progress reporting (if available)
363+
let content_length = response.content_length();
364+
let mut downloaded = 0u64;
365+
366+
// Stream the response body to file using chunks
367+
while let Some(chunk) = response.chunk().await? {
368+
// Write chunk to file
369+
writer.write_all(&chunk).await
370+
.map_err(|e| anyhow!("Error writing to file: {}", e))?;
371+
372+
// Update progress
373+
downloaded += chunk.len() as u64;
374+
if let Some(total) = content_length {
375+
let percent = (downloaded * 100) / total;
376+
print!("\rDownloading: {}% ({}/{} bytes)", percent, downloaded, total);
377+
} else {
378+
print!("\rDownloaded: {} bytes", downloaded);
379+
}
380+
std::io::Write::flush(&mut std::io::stdout()).unwrap();
381+
}
382+
383+
// Flush the writer to ensure all data is written to disk
384+
writer.flush().await
385+
.map_err(|e| anyhow!("Error flushing file: {}", e))?;
386+
387+
println!("\nDownload completed");
363388

364389

365390
// Verify downloaded file

0 commit comments

Comments
 (0)