Skip to content

Commit cb9d812

Browse files
downloader: Refactor retry logic and improve errors
1 parent 926d2fc commit cb9d812

File tree

2 files changed

+84
-39
lines changed

2 files changed

+84
-39
lines changed

src/downloader.rs

Lines changed: 74 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1+
use crate::{error::DownloadError, Error};
12
use reqwest::{Client, Url};
2-
use std::{path::PathBuf, sync::Arc};
3+
use std::{path::PathBuf, sync::Arc, time::Duration};
34
use tokio::{
45
fs::File,
56
io::AsyncWriteExt,
67
sync::{mpsc, oneshot, watch, Semaphore},
78
};
89

9-
use crate::Error;
10-
1110
const QUEUE_SIZE: usize = 100;
1211
const MAX_RETRIES: usize = 3;
1312

@@ -67,7 +66,6 @@ pub enum Status {
6766
InProgress(DownloadProgress),
6867
Completed,
6968
Retrying,
70-
Cancelled,
7169
Failed,
7270
}
7371

@@ -139,7 +137,7 @@ async fn dispatcher_thread(
139137
mut rx: mpsc::Receiver<DownloadRequest>,
140138
sem: Arc<Semaphore>,
141139
) {
142-
while let Some(mut request) = rx.recv().await {
140+
while let Some(request) = rx.recv().await {
143141
let permit = match sem.clone().acquire_owned().await {
144142
Ok(permit) => permit,
145143
Err(_) => break,
@@ -148,40 +146,81 @@ async fn dispatcher_thread(
148146
tokio::spawn(async move {
149147
// Move the permit into the worker thread so it's automatically released when the thread finishes
150148
let _permit = permit;
151-
loop {
152-
match download_thread(client.clone(), &mut request).await {
153-
Ok(file) => {
154-
let _ = request.status.send(Status::Completed);
155-
let _ = request.result.send(Ok(file));
156-
break;
157-
}
158-
Err(e) => {
159-
if request.remaining_retries > 0 {
160-
let _ = request.status.send(Status::Retrying);
161-
request.remaining_retries -= 1;
162-
} else {
163-
let status = match e {
164-
Error::Io(ref io_err) => {
165-
if io_err.kind() == std::io::ErrorKind::Interrupted {
166-
Status::Cancelled
167-
} else {
168-
Status::Failed
169-
}
170-
}
171-
_ => Status::Failed,
172-
};
173-
let _ = request.status.send(status);
174-
let _ = request.result.send(Err(e));
175-
break;
176-
}
177-
}
149+
download_thread(client.clone(), request).await;
150+
});
151+
}
152+
}
153+
154+
async fn download_thread(client: Client, mut req: DownloadRequest) {
155+
fn should_retry(e: &Error) -> bool {
156+
match e {
157+
Error::Reqwest(network_err) => {
158+
network_err.is_timeout()
159+
|| network_err.is_connect()
160+
|| network_err.is_request()
161+
|| network_err
162+
.status()
163+
.map(|status_code| status_code.is_server_error())
164+
.unwrap_or(true)
165+
}
166+
Error::Download(DownloadError::Cancelled) | Error::Io(_) => false,
167+
_ => false,
168+
}
169+
}
170+
171+
let mut last_error = None;
172+
for attempt in 0..=(MAX_RETRIES) {
173+
if attempt > MAX_RETRIES {
174+
req.status.send(Status::Failed).ok();
175+
req.result
176+
.send(Err(Error::Download(DownloadError::RetriesExhausted {
177+
last_error_msg: last_error
178+
.as_ref()
179+
.map(|e: &crate::Error| e.to_string())
180+
.unwrap_or_else(|| "Unknown Error".to_string()),
181+
})))
182+
.ok();
183+
return;
184+
}
185+
186+
if attempt > 0 {
187+
req.status.send(Status::Retrying).ok();
188+
// Basic exponential backoff
189+
let delay_ms = 1000 * 2u64.pow(attempt as u32 - 1);
190+
let delay = Duration::from_millis(delay_ms);
191+
192+
tokio::select! {
193+
_ = tokio::time::sleep(delay) => {},
194+
_ = &mut req.cancel => {
195+
req.status.send(Status::Failed).ok();
196+
req.result.send(Err(Error::Download(DownloadError::Cancelled))).ok();
197+
return;
178198
}
179199
}
180-
});
200+
}
201+
}
202+
203+
loop {
204+
match download(client.clone(), &mut req).await {
205+
Ok(file) => {
206+
req.status.send(Status::Completed).ok();
207+
req.result.send(Ok(file)).ok();
208+
return;
209+
}
210+
Err(e) => {
211+
if should_retry(&e) {
212+
last_error = Some(e);
213+
continue;
214+
}
215+
req.status.send(Status::Failed).ok();
216+
req.result.send(Err(e)).ok();
217+
return;
218+
}
219+
}
181220
}
182221
}
183222

184-
async fn download_thread(client: Client, req: &mut DownloadRequest) -> Result<File, Error> {
223+
async fn download(client: Client, req: &mut DownloadRequest) -> Result<File, Error> {
185224
let update_progress = |bytes_downloaded: u64, total_bytes: Option<u64>| {
186225
req.status
187226
.send(Status::InProgress(DownloadProgress {
@@ -211,10 +250,7 @@ async fn download_thread(client: Client, req: &mut DownloadRequest) -> Result<Fi
211250
_ = &mut req.cancel => {
212251
drop(file); // Manually drop the file handle to ensure that deletion doesn't fail
213252
tokio::fs::remove_file(&req.destination).await?;
214-
return Err(Error::Io(std::io::Error::new(
215-
std::io::ErrorKind::Interrupted,
216-
"Download cancelled",
217-
)));
253+
return Err(Error::Download(DownloadError::Cancelled));
218254
}
219255
chunk = response.chunk() => {
220256
match chunk {
@@ -230,7 +266,6 @@ async fn download_thread(client: Client, req: &mut DownloadRequest) -> Result<Fi
230266
return Err(Error::Reqwest(e))
231267
},
232268
}
233-
234269
}
235270
}
236271
}

src/error.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,14 @@ pub enum Error {
1010
Reqwest(#[from] reqwest::Error),
1111
#[error("Oneshot: {0}")]
1212
Oneshot(#[from] tokio::sync::oneshot::error::RecvError),
13+
#[error("Download: {0}")]
14+
Download(#[from] DownloadError),
15+
}
16+
17+
#[derive(Error, Debug, Clone)]
18+
pub enum DownloadError {
19+
#[error("Download was cancelled")]
20+
Cancelled,
21+
#[error("Retry limit exceeded")]
22+
RetriesExhausted { last_error_msg: String },
1323
}

0 commit comments

Comments
 (0)