diff --git a/Cargo.toml b/Cargo.toml
index 4cd73c1..a1d561a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,6 +11,9 @@ tracing = { workspace = true }
 tonic = { workspace = true }
 tokio = { workspace = true }
 prost = { workspace = true }
+reqwest = "0.12"
+tokio-util = { version = "0.7", features = ["rt"] }
+url = "2.5.4"
 
 [build-dependencies]
 tonic-build = "0.12.3"
diff --git a/src/downloader/config.rs b/src/downloader/config.rs
new file mode 100644
index 0000000..9a01ad3
--- /dev/null
+++ b/src/downloader/config.rs
@@ -0,0 +1,82 @@
+use std::{
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+    time::Duration,
+};
+
+#[derive(Debug, Clone)]
+pub struct DownloadManagerConfig {
+    max_concurrent: Arc<AtomicUsize>,
+    queue_size: usize,
+}
+
+impl Default for DownloadManagerConfig {
+    fn default() -> Self {
+        Self {
+            max_concurrent: Arc::new(AtomicUsize::new(3)),
+            queue_size: 100,
+        }
+    }
+}
+
+impl DownloadManagerConfig {
+    pub fn queue_size(&self) -> usize {
+        self.queue_size
+    }
+
+    pub fn max_concurrent(&self) -> usize {
+        self.max_concurrent.load(Ordering::Relaxed)
+    }
+
+    pub fn set_max_concurrent(&self, max: usize) {
+        self.max_concurrent.store(max, Ordering::Relaxed);
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct DownloadConfig {
+    max_retries: usize,
+    user_agent: Option<String>,
+    progress_update_interval: Duration,
+}
+
+impl Default for DownloadConfig {
+    fn default() -> Self {
+        Self {
+            max_retries: 3,
+            user_agent: None,
+            progress_update_interval: Duration::from_millis(1000),
+        }
+    }
+}
+
+impl DownloadConfig {
+    pub fn with_max_retries(mut self, retries: usize) -> Self {
+        self.max_retries = retries;
+        self
+    }
+
+    pub fn with_user_agent(mut self, user_agent: impl Into<String>) -> Self {
+        self.user_agent = Some(user_agent.into());
+        self
+    }
+
+    pub fn with_progress_interval(mut self, interval: Duration) -> Self {
+        self.progress_update_interval = interval;
+        self
+    }
+
+    pub fn max_retries(&self) -> usize {
+        self.max_retries
+    }
+
+    pub fn user_agent(&self) -> Option<&str> {
+        self.user_agent.as_deref()
+    }
+
+    pub fn progress_update_interval(&self) -> Duration {
+        self.progress_update_interval
+    }
+}
diff --git a/src/downloader/handle.rs b/src/downloader/handle.rs
new file mode 100644
index 0000000..5e8bc60
--- /dev/null
+++ b/src/downloader/handle.rs
@@ -0,0 +1,65 @@
+use super::Status;
+use crate::Error;
+use tokio::{
+    fs::File,
+    sync::{oneshot, watch},
+};
+use tokio_util::sync::CancellationToken;
+
+#[derive(Debug)]
+pub struct DownloadHandle {
+    result: oneshot::Receiver<Result<File, Error>>,
+    status: watch::Receiver<Status>,
+    cancel: CancellationToken,
+}
+
+impl DownloadHandle {
+    pub fn new(
+        result: oneshot::Receiver<Result<File, Error>>,
+        status: watch::Receiver<Status>,
+        cancel: CancellationToken,
+    ) -> Self {
+        Self {
+            result,
+            status,
+            cancel,
+        }
+    }
+
+    pub fn status(&self) -> Status {
+        *self.status.borrow()
+    }
+
+    pub fn is_completed(&self) -> bool {
+        matches!(self.status(), Status::Completed)
+    }
+
+    pub fn is_cancelled(&self) -> bool {
+        matches!(self.status(), Status::Cancelled)
+    }
+
+    pub async fn wait_for_status_update(&mut self) -> Result<(), watch::error::RecvError> {
+        self.status.changed().await
+    }
+
+    pub fn cancel(&self) {
+        self.cancel.cancel();
+    }
+}
+
+impl std::future::Future for DownloadHandle {
+    type Output = Result<File, Error>;
+
+    fn poll(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Self::Output> {
+        use std::pin::Pin;
+        use std::task::Poll;
+        match Pin::new(&mut self.result).poll(cx) {
+            Poll::Ready(Ok(result)) => Poll::Ready(result),
+            Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Oneshot(e))),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
diff --git a/src/downloader/manager.rs b/src/downloader/manager.rs
new file mode 100644
index 0000000..f009599
--- /dev/null
+++ b/src/downloader/manager.rs
@@ -0,0 +1,133 @@
+use super::{
+    worker::download_thread, DownloadBuilder, DownloadConfig, DownloadManagerConfig,
+    DownloadRequest,
+};
+use crate::{error::DownloadError, Error};
+use reqwest::Client;
+use std::{path::Path, sync::Arc};
+use tokio::sync::{mpsc, Semaphore};
+use tokio_util::{sync::CancellationToken, task::TaskTracker};
+use url::Url;
+
+#[derive(Debug)]
+pub struct DownloadManager {
+    queue: mpsc::Sender<DownloadRequest>,
+    semaphore: Arc<Semaphore>,
+    cancel: CancellationToken,
+    config: DownloadManagerConfig,
+    tracker: TaskTracker,
+}
+
+impl Default for DownloadManager {
+    fn default() -> Self {
+        Self::with_config(DownloadManagerConfig::default())
+    }
+}
+
+impl DownloadManager {
+    pub fn with_config(config: DownloadManagerConfig) -> Self {
+        let (tx, rx) = mpsc::channel(config.queue_size());
+        let client = Client::new();
+        let semaphore = Arc::new(Semaphore::new(config.max_concurrent()));
+        let tracker = TaskTracker::new();
+        let manager = Self {
+            queue: tx,
+            semaphore: semaphore.clone(),
+            cancel: CancellationToken::new(),
+            config,
+            tracker: tracker.clone(),
+        };
+        // Spawn the dispatcher task to handle download requests
+        tracker.spawn(dispatcher_thread(client, rx, semaphore, tracker.clone()));
+        manager
+    }
+
+    pub fn download(&self, url: Url, destination: impl AsRef<Path>) -> DownloadBuilder<'_> {
+        DownloadBuilder::new(self, url, destination)
+    }
+
+    pub fn download_with_config(
+        &self,
+        url: Url,
+        destination: impl AsRef<Path>,
+        config: DownloadConfig,
+    ) -> DownloadBuilder<'_> {
+        self.download(url, destination).with_config(config)
+    }
+
+    pub async fn set_max_parallel_downloads(&self, limit: usize) -> Result<(), Error> {
+        let current = self.config.max_concurrent();
+        if limit > current {
+            self.semaphore.add_permits(limit - current);
+        } else if limit < current {
+            let to_remove = current - limit;
+
+            let permits = self
+                .semaphore
+                .acquire_many(to_remove as u32)
+                .await
+                .map_err(|_| Error::Download(DownloadError::ManagerShutdown))?;
+
+            permits.forget();
+        }
+        self.config.set_max_concurrent(limit);
+
+        Ok(())
+    }
+
+    pub fn cancel_all(&self) {
+        self.cancel.cancel();
+    }
+
+    pub fn queued_downloads(&self) -> usize {
+        self.queue.max_capacity() - self.queue.capacity()
+    }
+
+    pub fn active_downloads(&self) -> usize {
+        // -1 because the dispatcher task is always running
+        self.tracker.len() - 1
+    }
+
+    pub async fn shutdown(self) -> Result<(), Error> {
+        self.cancel.cancel();
+        // Drop the sender first so the dispatcher's recv() loop can terminate before we wait
+        drop(self.queue);
+        self.tracker.close();
+        self.tracker.wait().await;
+        Ok(())
+    }
+
+    pub fn is_cancelled(&self) -> bool {
+        self.cancel.is_cancelled()
+    }
+
+    pub fn child_token(&self) -> CancellationToken {
+        self.cancel.child_token()
+    }
+
+    pub fn queue_request(&self, req: DownloadRequest) -> Result<(), Error> {
+        self.queue.try_send(req).map_err(|e| match e {
+            mpsc::error::TrySendError::Full(_) => Error::Download(DownloadError::QueueFull),
+            mpsc::error::TrySendError::Closed(_) => Error::Download(DownloadError::ManagerShutdown),
+        })
+    }
+}
+
+async fn dispatcher_thread(
+    client: Client,
+    mut rx: mpsc::Receiver<DownloadRequest>,
+    sem: Arc<Semaphore>,
+    tracker: TaskTracker,
+) {
+    while let Some(request) = rx.recv().await {
+        let permit = match sem.clone().acquire_owned().await {
+            Ok(permit) => permit,
+            Err(_) => break,
+        };
+        let client = client.clone();
+        tracker.spawn(async move {
+            // Move the permit into the worker task so it's automatically released when the task finishes
+            let _permit = permit;
+            download_thread(client.clone(), request).await;
+        });
+    }
+}
diff --git a/src/downloader/mod.rs b/src/downloader/mod.rs
new file mode 100644
index 0000000..51319db
--- /dev/null
+++ b/src/downloader/mod.rs
@@ -0,0 +1,12 @@
+mod config;
+mod handle;
+mod manager;
+mod progress;
+mod request;
+mod worker;
+
+pub use config::{DownloadConfig, DownloadManagerConfig};
+pub use handle::DownloadHandle;
+pub use manager::DownloadManager;
+pub use progress::{DownloadProgress, Status};
+pub use request::{DownloadBuilder, DownloadRequest};
diff --git a/src/downloader/progress.rs b/src/downloader/progress.rs
new file mode 100644
index 0000000..e492a10
--- /dev/null
+++ b/src/downloader/progress.rs
@@ -0,0 +1,124 @@
+use std::time::{Duration, Instant};
+
+const SPEED_UPDATE_INTERVAL: Duration = Duration::from_secs(1);
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct DownloadProgress {
+    bytes_downloaded: u64,
+    total_bytes: Option<u64>,
+    speed_bps: Option<u64>,
+    eta: Option<Duration>,
+
+    // For calculations
+    start_time: Instant,
+    last_update: Instant,
+    last_speed_update: Instant,
+    last_bytes_for_speed: u64,
+}
+
+impl std::fmt::Display for DownloadProgress {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let percent = self.percent().unwrap_or(0.0);
+        let speed = self
+            .speed()
+            .map(|s| format!("{:.2} B/s", s))
+            .unwrap_or("N/A".to_string());
+        let eta = self
+            .eta
+            .map(|d| format!("{:.2?}", d))
+            .unwrap_or("N/A".to_string());
+        let elapsed = self.elapsed();
+
+        write!(
+            f,
+            "Downloaded: {} bytes, Total: {:?}, Speed: {}, ETA: {}, Elapsed: {:.2?}, Progress: {:.2}%",
+            self.bytes_downloaded,
+            self.total_bytes,
+            speed,
+            eta,
+            elapsed,
+            percent
+        )
+    }
+}
+
+impl DownloadProgress {
+    pub fn new(bytes_downloaded: u64, total_bytes: Option<u64>) -> Self {
+        let now = Instant::now();
+        Self {
+            bytes_downloaded,
+            total_bytes,
+            speed_bps: None,
+            eta: None,
+
+            start_time: now,
+            last_update: now,
+            last_speed_update: now,
+            last_bytes_for_speed: bytes_downloaded,
+        }
+    }
+
+    pub fn update(&mut self, bytes_downloaded: u64) {
+        self.bytes_downloaded = bytes_downloaded;
+        self.update_speed(bytes_downloaded);
+        self.update_eta();
+    }
+
+    fn update_speed(&mut self, bytes_downloaded: u64) {
+        let now = Instant::now();
+
+        if now.duration_since(self.last_speed_update) >= SPEED_UPDATE_INTERVAL {
+            let byte_diff = (bytes_downloaded - self.last_bytes_for_speed) as f64;
+            let time_diff = now.duration_since(self.last_speed_update).as_secs_f64();
+
+            self.last_speed_update = now;
+            self.last_bytes_for_speed = bytes_downloaded;
+            self.speed_bps = Some((byte_diff / time_diff) as u64);
+        };
+    }
+
+    fn update_eta(&mut self) {
+        if let (Some(speed), Some(total)) = (self.speed_bps, self.total_bytes) {
+            if speed > 0 {
+                let remaining = total.saturating_sub(self.bytes_downloaded);
+                self.eta = Some(Duration::from_secs(remaining / speed));
+            }
+        }
+    }
+
+    pub fn percent(&self) -> Option<f64> {
+        self.total_bytes.map(|total| {
+            if total == 0 {
+                0.0
+            } else {
+                (self.bytes_downloaded as f64 / total as f64) * 100.0
+            }
+        })
+    }
+
+    pub fn total_bytes(&self) -> Option<u64> {
+        self.total_bytes
+    }
+
+    pub fn bytes_downloaded(&self) -> u64 {
+        self.bytes_downloaded
+    }
+
+    pub fn speed(&self) -> Option<u64> {
+        self.speed_bps
+    }
+
+    pub fn elapsed(&self) -> Duration {
+        self.start_time.elapsed()
+    }
+}
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum Status {
+    Queued,
+    InProgress(DownloadProgress),
+    Retrying,
+    Completed,
+    Failed,
+    Cancelled,
+}
diff --git a/src/downloader/request.rs b/src/downloader/request.rs
new file mode 100644
index 0000000..632e710
--- /dev/null
+++ b/src/downloader/request.rs
@@ -0,0 +1,161 @@
+use super::{DownloadConfig, DownloadHandle, DownloadManager, DownloadProgress, Status};
+use crate::{error::DownloadError, Error};
+use std::{
+    path::{Path, PathBuf},
+    time::{Duration, Instant},
+};
+use tokio::{
+    fs::File,
+    sync::{oneshot, watch},
+};
+use tokio_util::sync::CancellationToken;
+use url::Url;
+
+#[derive(Debug)]
+pub struct DownloadRequest {
+    url: Url,
+    destination: PathBuf,
+    pub result: oneshot::Sender<Result<File, Error>>,
+    pub status: watch::Sender<Status>,
+    pub cancel: CancellationToken,
+    config: DownloadConfig,
+
+    // For rate limiting
+    last_progress_update: Instant,
+}
+
+impl DownloadRequest {
+    pub fn new(
+        url: Url,
+        destination: impl AsRef<Path>,
+        cancel: CancellationToken,
+        config: DownloadConfig,
+    ) -> (Self, DownloadHandle) {
+        let (result_tx, result_rx) = oneshot::channel();
+        let (status_tx, status_rx) = watch::channel(Status::Queued);
+        (
+            Self {
+                url,
+                destination: destination.as_ref().to_path_buf(),
+                result: result_tx,
+                status: status_tx,
+                cancel: cancel.clone(),
+                config,
+                last_progress_update: Instant::now(),
+            },
+            DownloadHandle::new(result_rx, status_rx, cancel),
+        )
+    }
+
+    pub fn url(&self) -> &Url {
+        &self.url
+    }
+
+    pub fn destination(&self) -> &Path {
+        &self.destination
+    }
+
+    pub fn config(&self) -> &DownloadConfig {
+        &self.config
+    }
+
+    pub fn send_progress(&mut self, progress: DownloadProgress) {
+        if self.last_progress_update.elapsed() >= self.config.progress_update_interval() {
+            self.last_progress_update = Instant::now();
+            self.status.send(Status::InProgress(progress)).ok();
+        }
+    }
+
+    pub fn fail(self, e: Error) {
+        let status = if matches!(e, Error::Download(DownloadError::Cancelled)) {
+            Status::Cancelled
+        } else {
+            Status::Failed
+        };
+        self.status.send(status).ok();
+        self.result.send(Err(e)).ok();
+    }
+
+    pub fn retry(&self) {
+        self.status.send(Status::Retrying).ok();
+    }
+
+    pub fn cancel(self) {
+        self.status.send(Status::Cancelled).ok();
+        self.result
+            .send(Err(Error::Download(DownloadError::Cancelled)))
+            .ok();
+    }
+
+    pub fn complete(self, file: File) {
+        self.status.send(Status::Completed).ok();
+        self.result.send(Ok(file)).ok();
+    }
+}
+
+#[derive(Debug)]
+pub struct DownloadBuilder<'a> {
+    manager: &'a DownloadManager,
+    url: Url,
+    destination: PathBuf,
+    config: DownloadConfig,
+}
+
+impl<'a> DownloadBuilder<'a> {
+    pub fn new(manager: &'a DownloadManager, url: Url, destination: impl AsRef<Path>) -> Self {
+        Self {
+            manager,
+            url,
+            destination: destination.as_ref().to_path_buf(),
+            config: DownloadConfig::default(),
+        }
+    }
+
+    pub fn with_retries(mut self, retries: usize) -> Self {
+        self.config = self.config.with_max_retries(retries);
+        self
+    }
+
+    pub fn with_user_agent(mut self, user_agent: impl Into<String>) -> Self {
+        self.config = self.config.with_user_agent(user_agent);
+        self
+    }
+
+    /// Set how often progress updates are sent
+    pub fn with_progress_interval(mut self, interval: Duration) -> Self {
+        self.config = self.config.with_progress_interval(interval);
+        self
+    }
+
+    pub fn with_config(mut self, config: DownloadConfig) -> Self {
+        self.config = config;
+        self
+    }
+
+    pub fn url(&self) -> &Url {
+        &self.url
+    }
+
+    pub fn config(&self) -> &DownloadConfig {
+        &self.config
+    }
+
+    pub fn start(self) -> Result<DownloadHandle, Error> {
+        if self.manager.is_cancelled() {
+            return Err(Error::Download(DownloadError::ManagerShutdown));
+        }
+
+        if self.destination.exists() {
+            return Err(Error::Download(DownloadError::FileExists {
+                path: self.destination,
+            }));
+        }
+
+        let cancel = self.manager.child_token();
+        let (req, handle) = DownloadRequest::new(self.url, self.destination, cancel, self.config);
+
+        self.manager.queue_request(req)?;
+
+        Ok(handle)
+    }
+}
diff --git a/src/downloader/worker.rs b/src/downloader/worker.rs
new file mode 100644
index 0000000..2870734
--- /dev/null
+++ b/src/downloader/worker.rs
@@ -0,0 +1,120 @@
+use super::{DownloadProgress, DownloadRequest};
+use crate::{error::DownloadError, Error};
+use reqwest::Client;
+use std::time::Duration;
+use tokio::{fs::File, io::AsyncWriteExt};
+
+pub(super) async fn download_thread(client: Client, mut req: DownloadRequest) {
+    fn should_retry(e: &Error) -> bool {
+        match e {
+            Error::Reqwest(network_err) => {
+                network_err.is_timeout()
+                    || network_err.is_connect()
+                    || network_err.is_request()
+                    || network_err
+                        .status()
+                        .map(|status_code| status_code.is_server_error())
+                        .unwrap_or(true)
+            }
+            Error::Download(DownloadError::Cancelled) | Error::Io(_) => false,
+            _ => false,
+        }
+    }
+
+    let mut last_error = None;
+    let max_attempts = req.config().max_retries();
+    for attempt in 0..=(max_attempts + 1) {
+        if attempt > max_attempts {
+            req.fail(Error::Download(DownloadError::RetriesExhausted {
+                last_error_msg: last_error
+                    .as_ref()
+                    .map(ToString::to_string)
+                    .unwrap_or_else(|| "Unknown Error".to_string()),
+            }));
+            return;
+        }
+
+        if attempt > 0 {
+            req.retry();
+            // Basic exponential backoff
+            let delay_ms = 1000 * 2u64.pow(attempt as u32 - 1);
+            let delay = Duration::from_millis(delay_ms);
+
+            tokio::select! {
+                _ = tokio::time::sleep(delay) => {},
+                _ = req.cancel.cancelled() => {
+                    req.cancel();
+                    return;
+                }
+            }
+        }
+
+        match download(client.clone(), &mut req).await {
+            Ok(file) => {
+                req.complete(file);
+                return;
+            }
+            Err(e) => {
+                if should_retry(&e) {
+                    last_error = Some(e);
+                    continue;
+                }
+
+                req.fail(e);
+                return;
+            }
+        }
+    }
+}
+
+async fn download(client: Client, req: &mut DownloadRequest) -> Result<File, Error> {
+    let mut response = client
+        .get(req.url().as_ref())
+        .send()
+        .await?
+        .error_for_status()?;
+    let total_bytes = response.content_length();
+    let mut bytes_downloaded = 0u64;
+
+    // Create the destination directory if it doesn't exist
+    if let Some(parent) = req.destination().parent() {
+        tokio::fs::create_dir_all(parent).await?;
+    }
+    let mut file = File::create(&req.destination()).await?;
+
+    let mut progress = DownloadProgress::new(bytes_downloaded, total_bytes);
+    req.send_progress(progress);
+
+    loop {
+        tokio::select! {
+            _ = req.cancel.cancelled() => {
+                drop(file); // Manually drop the file handle to ensure that deletion doesn't fail
+                tokio::fs::remove_file(&req.destination()).await?;
+                return Err(Error::Download(DownloadError::Cancelled));
+            }
+            chunk = response.chunk() => {
+                match chunk {
+                    Ok(Some(chunk)) => {
+                        file.write_all(&chunk).await?;
+                        bytes_downloaded += chunk.len() as u64;
+
+                        progress.update(bytes_downloaded);
+                        req.send_progress(progress);
+                    }
+                    Ok(None) => break,
+                    Err(e) => {
+                        drop(file); // Manually drop the file handle to ensure that deletion doesn't fail
+                        tokio::fs::remove_file(&req.destination()).await?;
+                        return Err(Error::Reqwest(e))
+                    },
+                }
+            }
+        }
+    }
+
+    // Ensure the data is written to disk
+    file.sync_all().await?;
+    // Open a new file handle with RO permissions
+    let file = File::options().read(true).open(&req.destination()).await?;
+    Ok(file)
+}
diff --git a/src/error.rs b/src/error.rs
index 8f897fa..2bbea17 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,3 +1,4 @@
+use std::path::PathBuf;
 use thiserror::Error;
 
 #[derive(Error, Debug)]
@@ -6,4 +7,26 @@ pub enum Error {
     Io(#[from] std::io::Error),
     #[error("Serde: {0}")]
     Serde(#[from] serde_json::Error),
+    #[error("Reqwest: {0}")]
+    Reqwest(#[from] reqwest::Error),
+    #[error("Oneshot: {0}")]
+    Oneshot(#[from] tokio::sync::oneshot::error::RecvError),
+    #[error("Download: {0}")]
+    Download(#[from] DownloadError),
+}
+
+#[derive(Error, Debug, Clone)]
+pub enum DownloadError {
+    #[error("Download was cancelled")]
+    Cancelled,
+    #[error("Retry limit exceeded: {last_error_msg}")]
+    RetriesExhausted { last_error_msg: String },
+    #[error("Download queue is full")]
+    QueueFull,
+    #[error("Download manager has been shut down")]
+    ManagerShutdown,
+    #[error("File already exists: {path:?}")]
+    FileExists { path: PathBuf },
+    #[error("Invalid URL: {0}")]
+    InvalidUrl(#[from] url::ParseError),
 }
diff --git a/src/lib.rs b/src/lib.rs
index 9af6fcd..5527d4c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,7 @@
+pub mod downloader;
 mod error;
 pub mod runner;
+
 pub use error::Error;
 
 pub mod proto {
diff --git a/src/runner/gptk.rs b/src/runner/gptk.rs
index c824646..571a8b0 100644
--- a/src/runner/gptk.rs
+++ b/src/runner/gptk.rs
@@ -78,4 +78,12 @@ impl Runner for GPTK {
         // GPTK requires either x86_64 (Rosetta) or arm64 (Apple Silicon)
         arch_output == "i386" || arch_output == "arm64"
     }
+
+    fn info_mut(&mut self) -> &mut RunnerInfo {
+        &mut self.info
+    }
+
+    fn initialize(&self, _prefix: impl AsRef<Path>) -> Result<(), crate::Error> {
+        todo!()
+    }
 }
diff --git a/src/runner/wine.rs b/src/runner/wine.rs
index 48fa372..b005b59 100644
--- a/src/runner/wine.rs
+++ b/src/runner/wine.rs
@@ -17,6 +17,7 @@ pub struct Wine {
 /// Determines whether a Wine prefix should be configured for 32-bit or 64-bit
 /// Windows compatibility. This affects which Windows applications can run
 /// in the prefix
+#[allow(dead_code)]
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum PrefixArch {
     /// 32-bit Windows prefix architecture
@@ -30,6 +31,7 @@ pub enum PrefixArch {
 /// Specifies which version of Windows the Wine prefix should emulate.
 /// Different applications may require specific Windows versions for
 /// optimal compatibility.
+#[allow(dead_code)]
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum WindowsVersion {
     Win7,
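
Usage sketch (reviewer note, not part of the diff): a minimal example of how the downloader API introduced above might be driven. The crate name `the_crate`, the URL, the destination path, and the `#[tokio::main]` wrapper are illustrative assumptions; the calls themselves mirror the public API added in this change.

// Reviewer sketch only; `the_crate`, the URL, and the destination path are placeholders.
use the_crate::downloader::{DownloadManager, Status};
use url::Url;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let manager = DownloadManager::default();

    // Build and queue a download; `start` returns a handle that is both a status
    // watcher and a future resolving to the downloaded file.
    let mut handle = manager
        .download(
            Url::parse("https://example.com/archive.tar.gz")?,
            "/tmp/archive.tar.gz",
        )
        .with_retries(5)
        .start()?;

    // Print progress updates until the status channel closes (finished, failed, or cancelled).
    while handle.wait_for_status_update().await.is_ok() {
        if let Status::InProgress(progress) = handle.status() {
            println!("{progress}");
        }
    }

    // Await the final result: a read-only tokio::fs::File on success.
    let _file = handle.await?;

    manager.shutdown().await?;
    Ok(())
}

Per the manager code above, concurrency can also be adjusted at runtime with `set_max_parallel_downloads`, and `cancel_all`/`shutdown` tear everything down.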