Skip to content

Commit 5dba39c

Browse files
committed
ADD: Add Rust batch checksum, retry, and resume
1 parent 1e0edad commit 5dba39c

File tree

3 files changed

+151
-26
lines changed

3 files changed

+151
-26
lines changed

CHANGELOG.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
# Changelog
22

3-
## 0.33.1 - TBD
3+
## 0.34.0 - TBD
4+
5+
### Enhancements
6+
- Added batch download retry, resumption, and checksum verification
7+
- Changed setter for `batch::DownloadParams` to accept any `impl ToString` for
8+
`filename_to_download`
9+
10+
### Breaking changes
11+
- Changed `sha2` and `hex` to required dependencies
12+
13+
## 0.33.1 - 2025-08-26
414

515
### Enhancements
616
- Upgraded DBN version to 0.41.0:

Cargo.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,21 @@ historical = [
2828
"dep:serde_json",
2929
"tokio/fs"
3030
]
31-
live = ["dep:hex", "dep:sha2", "tokio/net"]
31+
live = ["tokio/net"]
3232

3333
[dependencies]
3434
dbn = { version = "0.41.0", features = ["async", "serde"] }
3535

3636
async-compression = { version = "0.4", features = ["tokio", "zstd"], optional = true }
3737
# Async stream trait
3838
futures = { version = "0.3", optional = true }
39-
# Used for Live authentication
40-
hex = { version = "0.4", optional = true }
39+
# Used for Live authentication and historical checksums
40+
hex = "0.4"
4141
reqwest = { version = "0.12", optional = true, features = ["json", "stream"], default-features = false }
4242
serde = { version = "1.0", optional = true, features = ["derive"] }
4343
serde_json = { version = "1.0", optional = true }
44-
# Used for Live authentication
45-
sha2 = { version = "0.10", optional = true }
44+
# Used for Live authentication and historical checksums
45+
sha2 = "0.10"
4646
thiserror = "2.0"
4747
time = { version = ">=0.3.35", features = ["macros", "parsing", "serde"] }
4848
tokio = { version = ">=1.38", features = ["io-util", "macros"] }

src/historical/batch.rs

Lines changed: 135 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,24 @@
22
33
use core::fmt;
44
use std::{
5+
cmp::Ordering,
56
collections::HashMap,
67
fmt::Write,
78
num::NonZeroU64,
9+
os::unix::fs::MetadataExt,
810
path::{Path, PathBuf},
911
str::FromStr,
1012
};
1113

1214
use dbn::{Compression, Encoding, SType, Schema};
1315
use futures::StreamExt;
16+
use hex::ToHex;
1417
use reqwest::RequestBuilder;
1518
use serde::{de, Deserialize, Deserializer};
19+
use sha2::{Digest, Sha256};
1620
use time::OffsetDateTime;
1721
use tokio::io::BufWriter;
18-
use tracing::info;
22+
use tracing::{debug, error, info, info_span, warn, Instrument};
1923
use typed_builder::TypedBuilder;
2024

2125
use crate::{historical::check_http_error, Error, Symbols};
@@ -144,10 +148,11 @@ impl BatchClient<'_> {
144148
.urls
145149
.get("https")
146150
.ok_or_else(|| Error::internal("Missing https URL for batch file"))?;
147-
self.download_file(https_url, &output_path).await?;
151+
self.download_file(https_url, &output_path, &file_desc.hash, file_desc.size)
152+
.await?;
148153
Ok(vec![output_path])
149154
} else {
150-
let mut paths = Vec::new();
155+
let mut paths = Vec::with_capacity(job_files.len());
151156
for file_desc in job_files.iter() {
152157
let output_path = params
153158
.output_dir
@@ -157,31 +162,136 @@ impl BatchClient<'_> {
157162
.urls
158163
.get("https")
159164
.ok_or_else(|| Error::internal("Missing https URL for batch file"))?;
160-
self.download_file(https_url, &output_path).await?;
165+
self.download_file(https_url, &output_path, &file_desc.hash, file_desc.size)
166+
.await?;
161167
paths.push(output_path);
162168
}
163169
Ok(paths)
164170
}
165171
}
166172

167-
async fn download_file(&mut self, url: &str, path: impl AsRef<Path>) -> crate::Result<()> {
173+
async fn download_file(
174+
&mut self,
175+
url: &str,
176+
path: &Path,
177+
hash: &str,
178+
exp_size: u64,
179+
) -> crate::Result<()> {
180+
const MAX_RETRIES: usize = 5;
168181
let url = reqwest::Url::parse(url)
169182
.map_err(|e| Error::internal(format!("Unable to parse URL: {e:?}")))?;
170-
let resp = self.inner.get_with_path(url.path())?.send().await?;
171-
let mut stream = check_http_error(resp).await?.bytes_stream();
172-
info!(%url, path=%path.as_ref().display(), "Downloading file");
173-
let mut output = BufWriter::new(
174-
tokio::fs::OpenOptions::new()
175-
.create(true)
176-
.truncate(true)
177-
.write(true)
178-
.open(path)
179-
.await?,
180-
);
181-
while let Some(chunk) = stream.next().await {
182-
tokio::io::copy(&mut chunk?.as_ref(), &mut output).await?;
183+
184+
let Some((hash_algo, exp_hash_hex)) = hash.split_once(':') else {
185+
return Err(Error::internal("Unexpected hash string format {hash:?}"));
186+
};
187+
let mut hasher = if hash_algo == "sha256" {
188+
Some(Sha256::new())
189+
} else {
190+
warn!(
191+
hash_algo,
192+
"Skipping checksum with unsupported hash algorithm"
193+
);
194+
None
195+
};
196+
197+
let span = info_span!("BatchDownload", %url, path=%path.display());
198+
async move {
199+
let mut retries = 0;
200+
'retry: loop {
201+
let mut req = self.inner.get_with_path(url.path())?;
202+
match Self::check_if_exists(path, exp_size).await? {
203+
Header::Skip => {
204+
return Ok(());
205+
}
206+
Header::Range(Some((key, val))) => {
207+
req = req.header(key, val);
208+
}
209+
Header::Range(None) => {}
210+
}
211+
let resp = req.send().await?;
212+
let mut stream = check_http_error(resp).await?.bytes_stream();
213+
info!("Downloading file");
214+
let mut output = BufWriter::new(
215+
tokio::fs::OpenOptions::new()
216+
.create(true)
217+
.append(true)
218+
.write(true)
219+
.open(path)
220+
.await?,
221+
);
222+
while let Some(chunk) = stream.next().await {
223+
let chunk = match chunk {
224+
Ok(chunk) => chunk,
225+
Err(err) if retries < MAX_RETRIES => {
226+
retries += 1;
227+
error!(?err, retries, "Retrying download");
228+
continue 'retry;
229+
}
230+
Err(err) => {
231+
return Err(crate::Error::from(err));
232+
}
233+
};
234+
if retries > 0 {
235+
retries = 0;
236+
info!("Resumed download");
237+
}
238+
if let Some(hasher) = hasher.as_mut() {
239+
hasher.update(&chunk)
240+
}
241+
tokio::io::copy(&mut chunk.as_ref(), &mut output).await?;
242+
}
243+
debug!("Completed download");
244+
Self::verify_hash(hasher, exp_hash_hex).await;
245+
return Ok(());
246+
}
247+
}
248+
.instrument(span)
249+
.await
250+
}
251+
252+
async fn check_if_exists(path: &Path, exp_size: u64) -> crate::Result<Header> {
253+
let Ok(metadata) = tokio::fs::metadata(path).await else {
254+
return Ok(Header::Range(None));
255+
};
256+
let actual_size = metadata.size();
257+
match actual_size.cmp(&exp_size) {
258+
Ordering::Less => {
259+
debug!(
260+
prev_downloaded_bytes = actual_size,
261+
total_bytes = exp_size,
262+
"Found existing file, resuming download"
263+
);
264+
}
265+
Ordering::Equal => {
266+
debug!("Skipping download as file already exists and matches expected size");
267+
return Ok(Header::Skip);
268+
}
269+
Ordering::Greater => {
270+
return Err(crate::Error::Io(std::io::Error::other(format!(
271+
"Batch file {} already exists with size {actual_size} which is larger than expected size {exp_size}",
272+
path.file_name().unwrap().display(),
273+
))));
274+
}
275+
}
276+
Ok(Header::Range(Some((
277+
"Range",
278+
format!("bytes={}-", metadata.size()),
279+
))))
280+
}
281+
282+
async fn verify_hash(hasher: Option<Sha256>, exp_hash_hex: &str) {
283+
let Some(hasher) = hasher else {
284+
return;
285+
};
286+
let hash_hex = hasher.finalize().encode_hex::<String>();
287+
if hash_hex != exp_hash_hex {
288+
warn!(
289+
hash_hex,
290+
exp_hash_hex, "Downloaded file failed checksum validation"
291+
);
292+
} else {
293+
debug!("Successfully verified checksum");
183294
}
184-
Ok(())
185295
}
186296

187297
const PATH_PREFIX: &'static str = "batch";
@@ -403,7 +513,7 @@ pub struct DownloadParams {
403513
#[builder(setter(transform = |dt: impl ToString| dt.to_string()))]
404514
pub job_id: String,
405515
/// `None` means all files associated with the job will be downloaded.
406-
#[builder(default, setter(strip_option))]
516+
#[builder(default, setter(transform = |filename: impl ToString| Some(filename.to_string())))]
407517
pub filename_to_download: Option<String>,
408518
}
409519

@@ -542,6 +652,11 @@ fn deserialize_compression<'de, D: serde::Deserializer<'de>>(
542652
Ok(opt.unwrap_or(Compression::None))
543653
}
544654

655+
enum Header {
656+
Skip,
657+
Range(Option<(&'static str, String)>),
658+
}
659+
545660
#[cfg(test)]
546661
mod tests {
547662
use reqwest::StatusCode;

0 commit comments

Comments
 (0)