diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..514cbe0 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,30 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +env: + CARGO_TERM_COLOR: always + +jobs: + fmt: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - run: cargo fmt --check + + clippy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: Swatinem/rust-cache@v2 + - run: cargo clippy --all-targets --all-features -- -D warnings + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: Swatinem/rust-cache@v2 + - run: cargo test --all-features diff --git a/Cargo.lock b/Cargo.lock index 8f4424b..b994aea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -243,7 +243,7 @@ dependencies = [ [[package]] name = "ark-srs" -version = "0.3.3" +version = "0.3.4" dependencies = [ "anyhow", "ark-bn254", diff --git a/src/kzg10/aztec20.rs b/src/kzg10/aztec20.rs index 13ecd60..0f1e6d8 100644 --- a/src/kzg10/aztec20.rs +++ b/src/kzg10/aztec20.rs @@ -339,7 +339,8 @@ mod test { while num_leading_zeros < p.coeffs().len() && p.coeffs()[num_leading_zeros].is_zero() { num_leading_zeros += 1; } - let coeffs = ark_std::cfg_iter!(&p.coeffs()[num_leading_zeros..]) + let coeffs = p.coeffs()[num_leading_zeros..] + .iter() .map(|s| s.into_bigint()) .collect::>(); (num_leading_zeros, coeffs) diff --git a/src/load.rs b/src/load.rs index 0dd805f..dc389f7 100644 --- a/src/load.rs +++ b/src/load.rs @@ -1,4 +1,4 @@ -//! Utils for persisting serialized data to files and loading them into memroy. +//! Utils for persisting serialized data to files and loading them into memory. //! We deal with `ark-serialize::CanonicalSerialize` compatible objects. use alloc::{borrow::ToOwned, format, string::String, vec::Vec}; @@ -34,12 +34,52 @@ pub fn load_data(src: PathBuf) -> Result { Ok(T::deserialize_uncompressed_unchecked(&bytes[..])?) } -pub(crate) fn download_url_to_file( - url: &str, - dest: &Path, +pub(crate) struct DownloadConfig { max_retries: usize, base_backoff: Duration, -) -> Result<()> { + connect_timeout: Duration, + read_timeout: Duration, +} + +impl Default for DownloadConfig { + fn default() -> Self { + Self { + max_retries: 5, + base_backoff: Duration::from_secs(1), + connect_timeout: Duration::from_secs(30), + read_timeout: Duration::from_secs(300), + } + } +} + +#[cfg(test)] +impl DownloadConfig { + pub(crate) fn with_max_retries(mut self, max_retries: usize) -> Self { + self.max_retries = max_retries; + self + } + + pub(crate) fn with_base_backoff(mut self, base_backoff: Duration) -> Self { + self.base_backoff = base_backoff; + self + } + + pub(crate) fn with_connect_timeout(mut self, connect_timeout: Duration) -> Self { + self.connect_timeout = connect_timeout; + self + } + + pub(crate) fn with_read_timeout(mut self, read_timeout: Duration) -> Self { + self.read_timeout = read_timeout; + self + } +} + +/// Download `url` into `dest` with retries and concurrent-download deduplication. +/// +/// Uses a `.part` file with an exclusive `flock` so that parallel callers +/// block rather than issuing redundant downloads. +fn download_url_to_file(url: &str, dest: &Path, config: &DownloadConfig) -> Result<()> { create_dir_all(dest.parent().context("no parent dir")?) .context("Unable to create directory")?; @@ -64,27 +104,36 @@ pub(crate) fn download_url_to_file( return Ok(()); } + let agent = ureq::AgentBuilder::new() + .timeout_connect(config.connect_timeout) + .timeout_read(config.read_timeout) + .build(); + let mut last_err = None; - for attempt in 0..=max_retries { - let mut buf: Vec = Vec::new(); - match ureq::get(url).call() { - Ok(resp) => match resp.into_reader().read_to_end(&mut buf) { - Ok(_) if buf.is_empty() => { - last_err = Some(anyhow!("zero-byte response")); - }, - Ok(_) => { + for attempt in 0..=config.max_retries { + match agent.get(url).call() { + Ok(resp) => { + let res = (|| -> Result { part_file.set_len(0)?; - (&part_file).write_all(&buf)?; - fs::rename(&part_path, dest)?; - return Ok(()); - }, - Err(e) => last_err = Some(anyhow::Error::from(e)), + let mut writer = &part_file; + let bytes = std::io::copy(&mut resp.into_reader(), &mut writer) + .context("failed streaming response to .part file")?; + Ok(bytes) + })(); + match res { + Ok(0) => last_err = Some(anyhow!("zero-byte response")), + Ok(_) => { + fs::rename(&part_path, dest)?; + return Ok(()); + }, + Err(e) => last_err = Some(e), + } }, Err(e) => last_err = Some(anyhow::Error::from(e)), } - if attempt < max_retries { - let backoff = base_backoff * 2u32.saturating_pow(attempt as u32); + if attempt < config.max_retries { + let backoff = config.base_backoff * 2u32.saturating_pow(attempt as u32); thread::sleep(backoff); } } @@ -102,7 +151,7 @@ pub fn download_srs_file(basename: &str, dest: impl AsRef) -> Result<()> { "https://github.com/EspressoSystems/ark-srs/releases/download/v{version}/{basename}", ); tracing::info!("Downloading SRS from {url}"); - download_url_to_file(&url, dest.as_ref(), 5, Duration::from_secs(1))?; + download_url_to_file(&url, dest.as_ref(), &DownloadConfig::default())?; tracing::info!("Saved SRS to {:?}", dest.as_ref()); Ok(()) } @@ -204,6 +253,7 @@ pub mod kzg10 { #[cfg(test)] mod tests { use super::*; + use std::net::TcpListener; use std::sync::{Arc, Barrier}; #[test] @@ -220,7 +270,10 @@ mod tests { let dest = dir.path().join("file.bin"); let url = format!("{}/file.bin", server.url()); - download_url_to_file(&url, &dest, 0, Duration::from_millis(10)).unwrap(); + let config = DownloadConfig::default() + .with_max_retries(0) + .with_base_backoff(Duration::from_millis(10)); + download_url_to_file(&url, &dest, &config).unwrap(); assert_eq!(std::fs::read_to_string(&dest).unwrap(), "hello"); mock.assert(); @@ -244,7 +297,10 @@ mod tests { let dest = dir.path().join("file.bin"); let url = format!("{}/file.bin", server.url()); - download_url_to_file(&url, &dest, 2, Duration::from_millis(10)).unwrap(); + let config = DownloadConfig::default() + .with_max_retries(2) + .with_base_backoff(Duration::from_millis(10)); + download_url_to_file(&url, &dest, &config).unwrap(); assert_eq!(std::fs::read_to_string(&dest).unwrap(), "ok"); } @@ -262,7 +318,10 @@ mod tests { let dest = dir.path().join("file.bin"); let url = format!("{}/file.bin", server.url()); - let result = download_url_to_file(&url, &dest, 2, Duration::from_millis(10)); + let config = DownloadConfig::default() + .with_max_retries(2) + .with_base_backoff(Duration::from_millis(10)); + let result = download_url_to_file(&url, &dest, &config); assert!(result.is_err()); assert!(!dest.exists()); @@ -290,7 +349,10 @@ mod tests { let dest = dest.clone(); thread::spawn(move || { barrier.wait(); - download_url_to_file(&url, &dest, 0, Duration::from_millis(10)) + let config = DownloadConfig::default() + .with_max_retries(0) + .with_base_backoff(Duration::from_millis(10)); + download_url_to_file(&url, &dest, &config) }) }) .collect(); @@ -318,7 +380,10 @@ mod tests { let url = format!("{}/file.bin", server.url()); - download_url_to_file(&url, &dest, 0, Duration::from_millis(10)).unwrap(); + let config = DownloadConfig::default() + .with_max_retries(0) + .with_base_backoff(Duration::from_millis(10)); + download_url_to_file(&url, &dest, &config).unwrap(); assert_eq!(std::fs::read_to_string(&dest).unwrap(), "existing"); mock.assert(); @@ -337,7 +402,10 @@ mod tests { let dest = dir.path().join("file.bin"); let url = format!("{}/file.bin", server.url()); - download_url_to_file(&url, &dest, 0, Duration::from_millis(10)).unwrap(); + let config = DownloadConfig::default() + .with_max_retries(0) + .with_base_backoff(Duration::from_millis(10)); + download_url_to_file(&url, &dest, &config).unwrap(); assert_eq!(std::fs::read_to_string(&dest).unwrap(), "data"); let mut part_path = dest.as_os_str().to_owned(); @@ -363,8 +431,49 @@ mod tests { std::fs::write(PathBuf::from(&part_path), "stale data").unwrap(); let url = format!("{}/file.bin", server.url()); - download_url_to_file(&url, &dest, 0, Duration::from_millis(10)).unwrap(); + let config = DownloadConfig::default() + .with_max_retries(0) + .with_base_backoff(Duration::from_millis(10)); + download_url_to_file(&url, &dest, &config).unwrap(); assert_eq!(std::fs::read_to_string(&dest).unwrap(), "fresh"); } + + // non-routable IP trick is slow on macOS (~60s due to OS SYN retransmit) + #[test] + #[cfg_attr(target_os = "macos", ignore)] + fn test_connect_timeout_returns_error() { + let dir = tempfile::tempdir().unwrap(); + let dest = dir.path().join("file.bin"); + let url = "http://10.255.255.1/file.bin"; + + let config = DownloadConfig::default() + .with_connect_timeout(Duration::from_millis(100)) + .with_max_retries(0); + let result = download_url_to_file(url, &dest, &config); + + assert!(result.is_err()); + } + + #[test] + fn test_read_timeout_returns_error() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let _handle = thread::spawn(move || { + let (_stream, _addr) = listener.accept().unwrap(); + thread::sleep(Duration::from_secs(5)); + }); + + let dir = tempfile::tempdir().unwrap(); + let dest = dir.path().join("file.bin"); + let url = format!("http://{}/file.bin", addr); + + let config = DownloadConfig::default() + .with_read_timeout(Duration::from_millis(100)) + .with_max_retries(0); + let result = download_url_to_file(&url, &dest, &config); + + assert!(result.is_err()); + } }