Skip to content

Commit 09b47b6

Browse files
downloader: Add DownloadConfig and rate limiting
1 parent 633b752 commit 09b47b6

File tree

6 files changed

+122
-68
lines changed

6 files changed

+122
-68
lines changed

src/downloader/config.rs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
use std::sync::{
2-
atomic::{AtomicUsize, Ordering},
3-
Arc,
1+
use std::{
2+
sync::{
3+
atomic::{AtomicUsize, Ordering},
4+
Arc,
5+
},
6+
time::Duration,
47
};
58

69
#[derive(Debug, Clone)]
@@ -31,3 +34,34 @@ impl DownloadManagerConfig {
3134
self.max_concurrent.store(max, Ordering::Relaxed);
3235
}
3336
}
37+
38+
#[derive(Debug, Clone)]
39+
pub struct DownloadConfig {
40+
max_retries: usize,
41+
user_agent: Option<String>,
42+
progress_update_interval: Duration,
43+
}
44+
45+
impl Default for DownloadConfig {
46+
fn default() -> Self {
47+
Self {
48+
max_retries: 3,
49+
user_agent: None,
50+
progress_update_interval: Duration::from_millis(1000),
51+
}
52+
}
53+
}
54+
55+
impl DownloadConfig {
56+
pub fn max_retries(&self) -> usize {
57+
self.max_retries
58+
}
59+
60+
pub fn user_agent(&self) -> Option<&str> {
61+
self.user_agent.as_deref()
62+
}
63+
64+
pub fn progress_update_interval(&self) -> Duration {
65+
self.progress_update_interval
66+
}
67+
}

src/downloader/manager.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use super::{config::DownloadManagerConfig, download_thread, DownloadHandle, DownloadRequest};
1+
use super::{
2+
config::DownloadManagerConfig, download_thread, DownloadConfig, DownloadHandle, DownloadRequest,
3+
};
24
use crate::{error::DownloadError, Error};
35
use reqwest::{Client, Url};
46
use std::{path::Path, sync::Arc};
@@ -38,10 +40,11 @@ impl DownloadManager {
3840
manager
3941
}
4042

41-
pub fn download(
43+
pub fn download_with_config(
4244
&self,
4345
url: Url,
4446
destination: impl AsRef<Path>,
47+
config: DownloadConfig,
4548
) -> Result<DownloadHandle, Error> {
4649
if self.cancel.is_cancelled() {
4750
return Err(Error::Download(DownloadError::ManagerShutdown));
@@ -55,7 +58,7 @@ impl DownloadManager {
5558
}
5659

5760
let cancel = self.cancel.child_token();
58-
let (req, handle) = DownloadRequest::new_req_handle_pair(url, destination, cancel);
61+
let (req, handle) = DownloadRequest::new(url, destination, cancel, config);
5962

6063
self.queue.try_send(req).map_err(|e| match e {
6164
mpsc::error::TrySendError::Full(_) => Error::Download(DownloadError::QueueFull),
@@ -65,6 +68,14 @@ impl DownloadManager {
6568
Ok(handle)
6669
}
6770

71+
pub fn download(
72+
&self,
73+
url: Url,
74+
destination: impl AsRef<Path>,
75+
) -> Result<DownloadHandle, Error> {
76+
self.download_with_config(url, destination, DownloadConfig::default())
77+
}
78+
6879
pub async fn set_max_parallel_downloads(&self, limit: usize) -> Result<(), Error> {
6980
let current = self.config.max_concurrent();
7081
if limit > current {

src/downloader/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ mod progress;
55
mod types;
66
mod worker;
77

8-
pub use config::DownloadManagerConfig;
8+
pub use config::{DownloadConfig, DownloadManagerConfig};
99
pub use handle::DownloadHandle;
1010
pub use manager::DownloadManager;
1111
pub use progress::DownloadProgress;
12-
pub use types::*;
12+
pub(self) use types::DownloadRequest;
13+
pub use types::Status;
1314
pub(self) use worker::download_thread;

src/downloader/progress.rs

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use std::time::{Duration, Instant};
22

3+
const SPEED_UPDATE_INTERVAL: Duration = Duration::from_secs(1);
4+
35
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
46
pub struct DownloadProgress {
57
bytes_downloaded: u64,
@@ -12,7 +14,6 @@ pub struct DownloadProgress {
1214
last_update: Instant,
1315
last_speed_update: Instant,
1416
last_bytes_for_speed: u64,
15-
update_interval: Duration,
1617
}
1718

1819
impl std::fmt::Display for DownloadProgress {
@@ -42,7 +43,7 @@ impl std::fmt::Display for DownloadProgress {
4243
}
4344

4445
impl DownloadProgress {
45-
pub fn new(bytes_downloaded: u64, total_bytes: Option<u64>, update_interval: Duration) -> Self {
46+
pub fn new(bytes_downloaded: u64, total_bytes: Option<u64>) -> Self {
4647
let now = Instant::now();
4748
Self {
4849
bytes_downloaded,
@@ -54,53 +55,35 @@ impl DownloadProgress {
5455
last_update: now,
5556
last_speed_update: now,
5657
last_bytes_for_speed: bytes_downloaded,
57-
update_interval,
5858
}
5959
}
6060

61-
pub fn update(&self, bytes_downloaded: u64) -> Option<Self> {
62-
fn new_update(
63-
progress: &DownloadProgress,
64-
bytes_downloaded: u64,
65-
instant: Instant,
66-
) -> DownloadProgress {
67-
DownloadProgress {
68-
eta: None,
69-
last_update: instant,
70-
bytes_downloaded,
71-
total_bytes: progress.total_bytes,
72-
speed_bps: progress.speed_bps,
73-
start_time: progress.start_time,
74-
last_speed_update: progress.last_speed_update,
75-
last_bytes_for_speed: progress.last_bytes_for_speed,
76-
update_interval: progress.update_interval,
77-
}
78-
}
61+
pub fn update(&mut self, bytes_downloaded: u64) {
62+
self.bytes_downloaded = bytes_downloaded;
63+
self.update_speed(bytes_downloaded);
64+
self.update_eta();
65+
}
7966

67+
fn update_speed(&mut self, bytes_downloaded: u64) {
8068
let now = Instant::now();
8169

82-
if now.duration_since(self.last_update) < self.update_interval {
83-
return None;
84-
}
85-
let mut new_update = new_update(self, bytes_downloaded, now);
86-
87-
if now.duration_since(self.last_speed_update) >= Duration::from_secs(1) {
70+
if now.duration_since(self.last_speed_update) >= SPEED_UPDATE_INTERVAL {
8871
let byte_diff = (bytes_downloaded - self.last_bytes_for_speed) as f64;
8972
let time_diff = now.duration_since(self.last_speed_update).as_secs_f64();
9073

91-
new_update.last_speed_update = now;
92-
new_update.last_bytes_for_speed = bytes_downloaded;
93-
new_update.speed_bps = Some((byte_diff / time_diff) as u64);
74+
self.last_speed_update = now;
75+
self.last_bytes_for_speed = bytes_downloaded;
76+
self.speed_bps = Some((byte_diff / time_diff) as u64);
9477
};
78+
}
9579

96-
if let (Some(speed), Some(total)) = (new_update.speed_bps, self.total_bytes) {
80+
fn update_eta(&mut self) {
81+
if let (Some(speed), Some(total)) = (self.speed_bps, self.total_bytes) {
9782
if speed > 0 {
98-
let remaining = total.saturating_sub(bytes_downloaded);
99-
new_update.eta = Some(Duration::from_secs(remaining / speed));
83+
let remaining = total.saturating_sub(self.bytes_downloaded);
84+
self.eta = Some(Duration::from_secs(remaining / speed));
10085
}
101-
};
102-
103-
Some(new_update)
86+
}
10487
}
10588

10689
pub fn percent(&self) -> Option<f64> {

src/downloader/types.rs

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
use super::{DownloadHandle, DownloadProgress};
2-
use crate::Error;
1+
use super::{DownloadConfig, DownloadHandle, DownloadProgress};
2+
use crate::{error::DownloadError, Error};
33
use reqwest::Url;
4-
use std::path::{Path, PathBuf};
4+
use std::{
5+
path::{Path, PathBuf},
6+
time::Instant,
7+
};
58
use tokio::{
69
fs::File,
710
sync::{oneshot, watch},
@@ -10,33 +13,58 @@ use tokio_util::sync::CancellationToken;
1013

1114
#[derive(Debug)]
1215
pub(crate) struct DownloadRequest {
13-
pub url: Url,
14-
pub destination: PathBuf,
16+
url: Url,
17+
destination: PathBuf,
1518
pub result: oneshot::Sender<Result<File, Error>>,
1619
pub status: watch::Sender<Status>,
1720
pub cancel: CancellationToken,
21+
config: DownloadConfig,
22+
23+
// For rate limiting
24+
last_progress_update: Instant,
1825
}
1926

2027
impl DownloadRequest {
21-
pub fn new_req_handle_pair(
28+
pub fn new(
2229
url: Url,
2330
destination: impl AsRef<Path>,
2431
cancel: CancellationToken,
32+
config: DownloadConfig,
2533
) -> (Self, DownloadHandle) {
2634
let (result_tx, result_rx) = oneshot::channel();
2735
let (status_tx, status_rx) = watch::channel(Status::Queued);
28-
2936
(
3037
Self {
3138
url,
3239
destination: destination.as_ref().to_path_buf(),
3340
result: result_tx,
3441
status: status_tx,
3542
cancel: cancel.clone(),
43+
config,
44+
last_progress_update: Instant::now(),
3645
},
3746
DownloadHandle::new(result_rx, status_rx, cancel),
3847
)
3948
}
49+
50+
pub fn url(&self) -> &Url {
51+
&self.url
52+
}
53+
54+
pub fn destination(&self) -> &Path {
55+
&self.destination
56+
}
57+
58+
pub fn config(&self) -> &DownloadConfig {
59+
&self.config
60+
}
61+
62+
pub fn send_progress(&mut self, progress: DownloadProgress) {
63+
if self.last_progress_update.elapsed() >= self.config.progress_update_interval() {
64+
self.last_progress_update = Instant::now();
65+
self.status.send(Status::InProgress(progress)).ok();
66+
}
67+
}
4068
}
4169

4270
#[derive(Debug, Copy, Clone, PartialEq, Eq)]

src/downloader/worker.rs

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ use reqwest::Client;
44
use std::time::Duration;
55
use tokio::{fs::File, io::AsyncWriteExt};
66

7-
const MAX_RETRIES: usize = 3;
8-
97
pub(super) async fn download_thread(client: Client, mut req: DownloadRequest) {
108
fn should_retry(e: &Error) -> bool {
119
match e {
@@ -24,8 +22,9 @@ pub(super) async fn download_thread(client: Client, mut req: DownloadRequest) {
2422
}
2523

2624
let mut last_error = None;
27-
for attempt in 0..=(MAX_RETRIES + 1) {
28-
if attempt > MAX_RETRIES {
25+
let max_attempts = req.config().max_retries();
26+
for attempt in 0..=(max_attempts + 1) {
27+
if attempt > max_attempts {
2928
req.status.send(Status::Failed).ok();
3029
req.result
3130
.send(Err(Error::Download(DownloadError::RetriesExhausted {
@@ -81,44 +80,42 @@ pub(super) async fn download_thread(client: Client, mut req: DownloadRequest) {
8180

8281
async fn download(client: Client, req: &mut DownloadRequest) -> Result<File, Error> {
8382
let mut response = client
84-
.get(req.url.as_ref())
83+
.get(req.url().as_ref())
8584
.send()
8685
.await?
8786
.error_for_status()?;
8887
let total_bytes = response.content_length();
8988
let mut bytes_downloaded = 0u64;
9089

9190
// Create the destination directory if it doesn't exist
92-
if let Some(parent) = req.destination.parent() {
91+
if let Some(parent) = req.destination().parent() {
9392
tokio::fs::create_dir_all(parent).await?;
9493
}
95-
let mut file = File::create(&req.destination).await?;
94+
let mut file = File::create(&req.destination()).await?;
9695

97-
let update_interval = Duration::from_millis(250);
98-
let mut progress = DownloadProgress::new(bytes_downloaded, total_bytes, update_interval);
99-
req.status.send(Status::InProgress(progress)).ok();
96+
let mut progress = DownloadProgress::new(bytes_downloaded, total_bytes);
97+
req.send_progress(progress);
10098

10199
loop {
102100
tokio::select! {
103101
_ = req.cancel.cancelled() => {
104102
drop(file); // Manually drop the file handle to ensure that deletion doesn't fail
105-
tokio::fs::remove_file(&req.destination).await?;
103+
tokio::fs::remove_file(&req.destination()).await?;
106104
return Err(Error::Download(DownloadError::Cancelled));
107105
}
108106
chunk = response.chunk() => {
109107
match chunk {
110108
Ok(Some(chunk)) => {
111109
file.write_all(&chunk).await?;
112110
bytes_downloaded += chunk.len() as u64;
113-
if let Some(new_progress) = progress.update(bytes_downloaded) {
114-
progress = new_progress;
115-
}
116-
req.status.send(Status::InProgress(progress)).ok();
111+
112+
progress.update(bytes_downloaded);
113+
req.send_progress(progress);
117114
}
118115
Ok(None) => break,
119116
Err(e) => {
120117
drop(file); // Manually drop the file handle to ensure that deletion doesn't fail
121-
tokio::fs::remove_file(&req.destination).await?;
118+
tokio::fs::remove_file(&req.destination()).await?;
122119
return Err(Error::Reqwest(e))
123120
},
124121
}
@@ -129,6 +126,6 @@ async fn download(client: Client, req: &mut DownloadRequest) -> Result<File, Err
129126
// Ensure the data is written to disk
130127
file.sync_all().await?;
131128
// Open a new file handle with RO permissions
132-
let file = File::options().read(true).open(&req.destination).await?;
129+
let file = File::options().read(true).open(&req.destination()).await?;
133130
Ok(file)
134131
}

0 commit comments

Comments
 (0)