Skip to content

Commit 1bf7bfd

Browse files
downloader: Use tokio::select to check for cancellations
1 parent 395ced4 commit 1bf7bfd

File tree

1 file changed

+26
-16
lines changed

1 file changed

+26
-16
lines changed

src/downloader.rs

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ async fn dispatcher_thread(
139139
// Move the permit into the worker thread so it's automatically released when the thread finishes
140140
let _permit = permit;
141141
loop {
142-
match download_thread(client.clone(), &request).await {
142+
match download_thread(client.clone(), &mut request).await {
143143
Ok(file) => {
144144
let _ = request.status.send(Status::Completed);
145145
let _ = request.result.send(Ok(file));
@@ -171,7 +171,7 @@ async fn dispatcher_thread(
171171
}
172172
}
173173

174-
async fn download_thread(client: Client, req: &DownloadRequest) -> Result<File, Error> {
174+
async fn download_thread(client: Client, req: &mut DownloadRequest) -> Result<File, Error> {
175175
let mut response = client.get(req.url.as_ref()).send().await?;
176176
let total_bytes = response.content_length();
177177
let mut bytes_downloaded = 0u64;
@@ -182,21 +182,31 @@ async fn download_thread(client: Client, req: &DownloadRequest) -> Result<File,
182182
}
183183
let mut file = File::create(&req.destination).await?;
184184

185-
while let Some(chunk) = response.chunk().await.transpose() {
186-
if !req.cancel.is_empty() {
187-
tokio::fs::remove_file(&req.destination).await?;
188-
return Err(Error::Io(std::io::Error::new(
189-
std::io::ErrorKind::Interrupted,
190-
"Download cancelled",
191-
)));
185+
loop {
186+
tokio::select! {
187+
_ = &mut req.cancel => {
188+
tokio::fs::remove_file(&req.destination).await?;
189+
return Err(Error::Io(std::io::Error::new(
190+
std::io::ErrorKind::Interrupted,
191+
"Download cancelled",
192+
)));
193+
}
194+
chunk = response.chunk() => {
195+
match chunk {
196+
Ok(Some(chunk)) => {
197+
file.write_all(&chunk).await?;
198+
bytes_downloaded += chunk.len() as u64;
199+
let _ = req.status.send(Status::InProgress(DownloadProgress {
200+
bytes_downloaded,
201+
total_bytes,
202+
}));
203+
}
204+
Ok(None) => break,
205+
Err(e) => return Err(Error::Reqwest(e)),
206+
}
207+
208+
}
192209
}
193-
let chunk = chunk?;
194-
file.write_all(&chunk).await?;
195-
bytes_downloaded += chunk.len() as u64;
196-
let _ = req.status.send(Status::InProgress(DownloadProgress {
197-
bytes_downloaded,
198-
total_bytes,
199-
}));
200210
}
201211

202212
// Ensure the data is written to disk

0 commit comments

Comments
 (0)