Skip to content

Commit bfe7bb1

Browse files
downloader: Implement a Download Scheduler
- Also Re-Queue requests on retry
1 parent 4540872 commit bfe7bb1

File tree

4 files changed

+363
-249
lines changed

4 files changed

+363
-249
lines changed

crates/download-manager/src/download_manager.rs

Lines changed: 10 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@ mod download;
33
mod error;
44
mod events;
55
mod request;
6-
mod worker;
6+
mod scheduler;
77

8-
use crate::{context::Context, request::RequestBuilder, worker::download_thread};
8+
use crate::{
9+
context::Context,
10+
request::RequestBuilder,
11+
scheduler::{Scheduler, SchedulerCmd},
12+
};
913
pub use crate::{
1014
context::DownloadID,
1115
download::{Download, DownloadResult},
@@ -24,7 +28,7 @@ use tokio_stream::wrappers::BroadcastStream;
2428
use tokio_util::{sync::CancellationToken, task::TaskTracker};
2529

2630
pub struct DownloadManager {
27-
queue: mpsc::Sender<Request>,
31+
scheduler_tx: mpsc::Sender<SchedulerCmd>,
2832
ctx: Arc<Context>,
2933
tracker: TaskTracker,
3034
}
@@ -55,22 +59,6 @@ impl DownloadManager {
5559
Request::builder(self)
5660
}
5761

58-
fn queue_request(&self, request: Request) -> Result<(), DownloadError> {
59-
self.queue.try_send(request).map_err(|e| match e {
60-
mpsc::error::TrySendError::Full(_) => DownloadError::QueueFull,
61-
mpsc::error::TrySendError::Closed(_) => DownloadError::ManagerShutdown,
62-
})
63-
}
64-
65-
/// Returns the count of pending requests still buffered in the internal mpsc channel.
66-
///
67-
/// **Note**: This excludes any request already dequeued by the dispatcher but not yet started.
68-
///
69-
/// Consider replacing with an explicit atomic counter.
70-
pub fn queued_downloads(&self) -> usize {
71-
self.queue.max_capacity() - self.queue.capacity()
72-
}
73-
7462
pub fn active_downloads(&self) -> usize {
7563
self.ctx.active.load(Ordering::Relaxed)
7664
}
@@ -129,54 +117,16 @@ impl DownloadManagerBuilder {
129117
let (tx, rx) = mpsc::channel(queue_size);
130118
let ctx = Context::new(max_concurrent);
131119
let tracker = TaskTracker::new();
120+
let scheduler = Scheduler::new(ctx.clone(), tracker.clone(), rx);
132121

133122
let manager = DownloadManager {
134-
queue: tx,
123+
scheduler_tx: tx,
135124
ctx: ctx.clone(),
136125
tracker: tracker.clone(),
137126
};
138127

139-
tracker.spawn(dispatcher_thread(rx, tracker.clone(), ctx));
128+
tracker.spawn(async move { scheduler.run().await });
140129

141130
Ok(manager)
142131
}
143132
}
144-
145-
async fn dispatcher_thread(
146-
mut rx: mpsc::Receiver<Request>,
147-
tracker: TaskTracker,
148-
ctx: Arc<Context>,
149-
) {
150-
struct ActiveGuard {
151-
ctx: Arc<Context>,
152-
_permit: tokio::sync::OwnedSemaphorePermit,
153-
}
154-
155-
impl Drop for ActiveGuard {
156-
fn drop(&mut self) {
157-
self.ctx
158-
.active
159-
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
160-
}
161-
}
162-
163-
while let Some(request) = rx.recv().await {
164-
let guard = match ctx.semaphore.clone().acquire_owned().await {
165-
Ok(p) => {
166-
ctx.active.fetch_add(1, Ordering::Relaxed);
167-
ActiveGuard {
168-
ctx: ctx.clone(),
169-
_permit: p,
170-
}
171-
}
172-
Err(_) => break,
173-
};
174-
175-
let ctx_clone = ctx.clone();
176-
tracker.spawn(async move {
177-
// Move the guard into the worker thread so it's automatically released when the thread finishes
178-
let _guard = guard;
179-
download_thread(request, ctx_clone).await
180-
});
181-
}
182-
}

crates/download-manager/src/request.rs

Lines changed: 33 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
use crate::{
2-
Download, DownloadError, DownloadEvent, DownloadID, DownloadManager, DownloadResult, Progress,
2+
Download, DownloadEvent, DownloadID, DownloadManager, Progress, scheduler::SchedulerCmd,
33
};
44
use derive_builder::Builder;
5-
use reqwest::Url;
6-
use std::{
7-
path::{Path, PathBuf},
8-
time::Duration,
5+
use reqwest::{
6+
Url,
7+
header::{HeaderMap, IntoHeaderName},
98
};
9+
use std::path::{Path, PathBuf};
1010
use tokio::sync::{broadcast, oneshot, watch};
1111
use tokio_util::sync::CancellationToken;
1212

13+
#[derive(Debug, Clone)]
1314
pub struct Request {
1415
id: DownloadID,
1516
url: Url,
@@ -18,28 +19,27 @@ pub struct Request {
1819

1920
progress: watch::Sender<Progress>,
2021
events: broadcast::Sender<DownloadEvent>,
21-
result: oneshot::Sender<Result<DownloadResult, DownloadError>>,
2222

2323
pub cancel_token: CancellationToken,
2424
}
2525

26-
#[derive(Debug, Builder)]
26+
#[derive(Debug, Builder, Clone)]
2727
#[builder(pattern = "owned")]
2828
pub struct DownloadConfig {
2929
#[builder(default = "3")]
3030
retries: u32,
31-
#[builder(default, setter(strip_option))]
32-
user_agent: Option<String>,
3331
#[builder(default = "false")]
3432
overwrite: bool,
33+
#[builder(default = "HeaderMap::new()", setter(skip))]
34+
headers: HeaderMap,
3535
}
3636

3737
impl Default for DownloadConfig {
3838
fn default() -> Self {
3939
DownloadConfig {
4040
retries: 3,
41-
user_agent: None,
4241
overwrite: false,
42+
headers: HeaderMap::new(),
4343
}
4444
}
4545
}
@@ -49,13 +49,13 @@ impl DownloadConfig {
4949
self.retries
5050
}
5151

52-
pub fn user_agent(&self) -> Option<&str> {
53-
self.user_agent.as_deref()
54-
}
55-
5652
pub fn overwrite(&self) -> bool {
5753
self.overwrite
5854
}
55+
56+
pub fn headers(&self) -> &HeaderMap {
57+
&self.headers
58+
}
5959
}
6060

6161
impl Request {
@@ -64,6 +64,7 @@ impl Request {
6464
url: None,
6565
destination: None,
6666
config: DownloadConfigBuilder::default(),
67+
headers: HeaderMap::new(),
6768
manager,
6869
}
6970
}
@@ -87,20 +88,11 @@ impl Request {
8788
&self.config
8889
}
8990

90-
pub fn is_cancelled(&self) -> bool {
91-
self.cancel_token.is_cancelled()
92-
}
93-
94-
fn emit(&self, event: DownloadEvent) {
91+
pub fn emit(&self, event: DownloadEvent) {
9592
// TODO: Log the error
9693
let _ = self.events.send(event);
9794
}
9895

99-
fn send_result(self, result: Result<DownloadResult, DownloadError>) {
100-
// TODO: Log the error
101-
let _ = self.result.send(result);
102-
}
103-
10496
pub fn update_progress(&self, progress: Progress) {
10597
// TODO: Log the error
10698
let _ = self.progress.send(progress);
@@ -110,42 +102,17 @@ impl Request {
110102
self.emit(DownloadEvent::Started {
111103
id: self.id(),
112104
url: self.url().clone(),
113-
destination: self.destination.clone(),
105+
destination: self.destination().to_path_buf(),
114106
total_bytes: None,
115107
});
116108
}
117-
118-
pub fn fail(self, error: DownloadError) {
119-
self.send_result(Err(error));
120-
}
121-
122-
pub fn finish(self, result: DownloadResult) {
123-
self.emit(DownloadEvent::Completed {
124-
id: self.id(),
125-
path: result.path.clone(),
126-
bytes_downloaded: result.bytes_downloaded,
127-
});
128-
self.send_result(Ok(result))
129-
}
130-
131-
pub fn retry(&self, attempt: u32, delay: Duration) {
132-
self.emit(DownloadEvent::Retrying {
133-
id: self.id(),
134-
attempt,
135-
next_delay_ms: delay.as_millis() as u64,
136-
});
137-
}
138-
139-
pub fn cancel(self) {
140-
self.emit(DownloadEvent::Cancelled { id: self.id() });
141-
self.send_result(Err(DownloadError::Cancelled))
142-
}
143109
}
144110

145111
pub struct RequestBuilder<'a> {
146112
url: Option<Url>,
147113
destination: Option<PathBuf>,
148114
config: DownloadConfigBuilder,
115+
headers: HeaderMap,
149116

150117
manager: &'a DownloadManager,
151118
}
@@ -166,47 +133,48 @@ impl RequestBuilder<'_> {
166133
self
167134
}
168135

169-
pub fn user_agent(mut self, user_agent: impl AsRef<str>) -> Self {
170-
self.config = self.config.user_agent(user_agent.as_ref().into());
171-
self
136+
pub fn user_agent(self, user_agent: impl AsRef<str>) -> Self {
137+
self.header(reqwest::header::USER_AGENT, user_agent)
172138
}
173139

174140
pub fn overwrite(mut self, overwrite: bool) -> Self {
175141
self.config = self.config.overwrite(overwrite);
176142
self
177143
}
178144

145+
pub fn header(mut self, header: impl IntoHeaderName, value: impl AsRef<str>) -> Self {
146+
self.headers.insert(header, value.as_ref().parse().unwrap());
147+
self
148+
}
149+
179150
pub fn start(self) -> anyhow::Result<Download> {
180151
let url = self.url.ok_or_else(|| anyhow::anyhow!("URL must be set"))?;
181152
let destination = self
182153
.destination
183154
.ok_or_else(|| anyhow::anyhow!("Destination must be set"))?;
184155
let config = self.config.build()?;
185156

186-
let (progress_tx, progress_rx) = watch::channel(Progress::new(None));
157+
let id = self.manager.ctx.next_id();
187158
let (result_tx, result_rx) = oneshot::channel();
159+
let (progress_tx, progress_rx) = watch::channel(Progress::new(None));
188160
let cancel_token = self.manager.child_token();
189-
190161
let event_tx = self.manager.ctx.events.clone();
191162
let event_rx = event_tx.subscribe();
192-
let id = self.manager.ctx.next_id();
163+
193164
let request = Request {
194165
id,
195166
url: url.clone(),
196167
destination: destination.clone(),
197168
config,
169+
170+
events: event_tx,
198171
progress: progress_tx,
199-
events: event_tx.clone(),
200-
result: result_tx,
201172
cancel_token: cancel_token.clone(),
202173
};
203174

204-
self.manager.queue_request(request)?;
205-
event_tx.send(DownloadEvent::Queued {
206-
id,
207-
url,
208-
destination,
209-
})?;
175+
self.manager
176+
.scheduler_tx
177+
.try_send(SchedulerCmd::Enqueue { request, result_tx });
210178

211179
Ok(Download::new(
212180
id,

0 commit comments

Comments
 (0)