Skip to content

Commit 748106a

Browse files
downloader: Implement DownloadManager
1 parent dbbe919 commit 748106a

File tree

2 files changed

+156
-0
lines changed

2 files changed

+156
-0
lines changed

src/downloader.rs

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
use reqwest::{Client, Url};
2+
use std::{fs::File, io::Write, path::PathBuf, sync::Arc};
3+
use tokio::sync::{broadcast, mpsc, oneshot, watch, Semaphore};
4+
5+
const QUEUE_SIZE: usize = 100;
6+
7+
#[derive(Debug)]
8+
struct DownloadRequest {
9+
url: Url,
10+
destination: PathBuf,
11+
result: oneshot::Sender<Result<File, reqwest::Error>>,
12+
status: watch::Sender<Status>,
13+
progress: broadcast::Sender<DownloadProgress>,
14+
}
15+
16+
#[derive(Debug, Clone, Copy)]
17+
pub struct DownloadProgress {
18+
pub bytes_downloaded: u64,
19+
pub total_bytes: Option<u64>,
20+
}
21+
22+
#[derive(Debug)]
23+
pub struct DownloadHandle {
24+
result: oneshot::Receiver<Result<File, reqwest::Error>>,
25+
status: watch::Receiver<Status>,
26+
progress: broadcast::Receiver<DownloadProgress>,
27+
}
28+
29+
impl DownloadHandle {
30+
pub async fn r#await(self) -> Result<std::fs::File, reqwest::Error> {
31+
match self.result.await {
32+
Ok(result) => result,
33+
Err(_) => todo!(),
34+
}
35+
}
36+
37+
pub fn status(&self) -> Status {
38+
self.status.borrow().clone()
39+
}
40+
41+
pub fn subscribe_progress(&self) -> &broadcast::Receiver<DownloadProgress> {
42+
&self.progress
43+
}
44+
}
45+
46+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
47+
pub enum Status {
48+
Pending,
49+
InProgress,
50+
Completed,
51+
Failed,
52+
}
53+
54+
#[derive(Debug)]
55+
pub struct DownloadManager {
56+
queue: mpsc::Sender<DownloadRequest>,
57+
semaphore: Arc<Semaphore>,
58+
}
59+
60+
impl Drop for DownloadManager {
61+
fn drop(&mut self) {
62+
// Need to manually close the semaphore to make sure dispatcher_thread stops waiting for permits
63+
self.semaphore.close();
64+
}
65+
}
66+
67+
impl DownloadManager {
68+
pub fn new(limit: usize) -> Self {
69+
let (tx, rx) = mpsc::channel(QUEUE_SIZE);
70+
let client = Client::new();
71+
let semaphore = Arc::new(Semaphore::new(limit));
72+
let manager = Self {
73+
queue: tx,
74+
semaphore: semaphore.clone(),
75+
};
76+
// Spawn the dispatcher thread to handle download requests
77+
tokio::spawn(async move { dispatcher_thread(client, rx, semaphore).await });
78+
manager
79+
}
80+
81+
pub fn set_max_parallel_downloads(&self, limit: usize) {
82+
let current = self.semaphore.available_permits();
83+
if limit > current {
84+
self.semaphore.add_permits(limit - current);
85+
} else if limit < current {
86+
let to_remove = current - limit;
87+
for _ in 0..to_remove {
88+
let _ = self.semaphore.try_acquire();
89+
}
90+
}
91+
}
92+
93+
pub fn add_request(&self, url: Url, destination: PathBuf) -> DownloadHandle {
94+
let (result_tx, result_rx) = oneshot::channel();
95+
let (status_tx, status_rx) = watch::channel(Status::Pending);
96+
let (progress_tx, progress_rx) = broadcast::channel(16);
97+
98+
let req = DownloadRequest {
99+
url,
100+
destination,
101+
result: result_tx,
102+
status: status_tx,
103+
progress: progress_tx,
104+
};
105+
106+
let _ = self.queue.try_send(req);
107+
108+
DownloadHandle {
109+
result: result_rx,
110+
status: status_rx,
111+
progress: progress_rx,
112+
}
113+
}
114+
}
115+
116+
async fn dispatcher_thread(
117+
client: Client,
118+
mut rx: mpsc::Receiver<DownloadRequest>,
119+
sem: Arc<Semaphore>,
120+
) {
121+
while let Some(request) = rx.recv().await {
122+
let permit = match sem.clone().acquire_owned().await {
123+
Ok(permit) => permit,
124+
Err(_) => break,
125+
};
126+
let client = client.clone();
127+
tokio::spawn(async move {
128+
// Move the permit into the worker thread so it's automatically released when the thread finishes
129+
let _permit = permit;
130+
let _ = download_thread(client, request).await;
131+
});
132+
}
133+
}
134+
135+
async fn download_thread(
136+
client: Client,
137+
req: DownloadRequest,
138+
) -> Result<(), Box<dyn std::error::Error>> {
139+
let mut resp = client.get(req.url).send().await?;
140+
let total = resp.content_length();
141+
let mut file = File::create(&req.destination)?;
142+
// let mut stream = resp.bytes().await?;
143+
let mut downloaded = 0u64;
144+
while let Some(chunk) = resp.chunk().await.transpose() {
145+
let chunk = chunk?;
146+
file.write_all(&chunk)?;
147+
downloaded += chunk.len() as u64;
148+
let _ = req.progress.send(DownloadProgress {
149+
bytes_downloaded: downloaded,
150+
total_bytes: total,
151+
});
152+
}
153+
Ok(())
154+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
pub mod downloader;
12
mod error;
23
pub mod runner;
4+
35
pub use error::Error;
46

57
pub mod proto {

0 commit comments

Comments
 (0)