22
33use core:: fmt;
44use std:: {
5+ cmp:: Ordering ,
56 collections:: HashMap ,
67 fmt:: Write ,
78 num:: NonZeroU64 ,
9+ os:: unix:: fs:: MetadataExt ,
810 path:: { Path , PathBuf } ,
911 str:: FromStr ,
1012} ;
1113
1214use dbn:: { Compression , Encoding , SType , Schema } ;
1315use futures:: StreamExt ;
16+ use hex:: ToHex ;
1417use reqwest:: RequestBuilder ;
1518use serde:: { de, Deserialize , Deserializer } ;
19+ use sha2:: { Digest , Sha256 } ;
1620use time:: OffsetDateTime ;
1721use tokio:: io:: BufWriter ;
18- use tracing:: info;
22+ use tracing:: { debug , error , info, info_span , warn , Instrument } ;
1923use typed_builder:: TypedBuilder ;
2024
2125use crate :: { historical:: check_http_error, Error , Symbols } ;
@@ -144,10 +148,11 @@ impl BatchClient<'_> {
144148 . urls
145149 . get ( "https" )
146150 . ok_or_else ( || Error :: internal ( "Missing https URL for batch file" ) ) ?;
147- self . download_file ( https_url, & output_path) . await ?;
151+ self . download_file ( https_url, & output_path, & file_desc. hash , file_desc. size )
152+ . await ?;
148153 Ok ( vec ! [ output_path] )
149154 } else {
150- let mut paths = Vec :: new ( ) ;
155+ let mut paths = Vec :: with_capacity ( job_files . len ( ) ) ;
151156 for file_desc in job_files. iter ( ) {
152157 let output_path = params
153158 . output_dir
@@ -157,31 +162,136 @@ impl BatchClient<'_> {
157162 . urls
158163 . get ( "https" )
159164 . ok_or_else ( || Error :: internal ( "Missing https URL for batch file" ) ) ?;
160- self . download_file ( https_url, & output_path) . await ?;
165+ self . download_file ( https_url, & output_path, & file_desc. hash , file_desc. size )
166+ . await ?;
161167 paths. push ( output_path) ;
162168 }
163169 Ok ( paths)
164170 }
165171 }
166172
167- async fn download_file ( & mut self , url : & str , path : impl AsRef < Path > ) -> crate :: Result < ( ) > {
173+ async fn download_file (
174+ & mut self ,
175+ url : & str ,
176+ path : & Path ,
177+ hash : & str ,
178+ exp_size : u64 ,
179+ ) -> crate :: Result < ( ) > {
180+ const MAX_RETRIES : usize = 5 ;
168181 let url = reqwest:: Url :: parse ( url)
169182 . map_err ( |e| Error :: internal ( format ! ( "Unable to parse URL: {e:?}" ) ) ) ?;
170- let resp = self . inner . get_with_path ( url. path ( ) ) ?. send ( ) . await ?;
171- let mut stream = check_http_error ( resp) . await ?. bytes_stream ( ) ;
172- info ! ( %url, path=%path. as_ref( ) . display( ) , "Downloading file" ) ;
173- let mut output = BufWriter :: new (
174- tokio:: fs:: OpenOptions :: new ( )
175- . create ( true )
176- . truncate ( true )
177- . write ( true )
178- . open ( path)
179- . await ?,
180- ) ;
181- while let Some ( chunk) = stream. next ( ) . await {
182- tokio:: io:: copy ( & mut chunk?. as_ref ( ) , & mut output) . await ?;
183+
184+ let Some ( ( hash_algo, exp_hash_hex) ) = hash. split_once ( ':' ) else {
185+ return Err ( Error :: internal ( "Unexpected hash string format {hash:?}" ) ) ;
186+ } ;
187+ let mut hasher = if hash_algo == "sha256" {
188+ Some ( Sha256 :: new ( ) )
189+ } else {
190+ warn ! (
191+ hash_algo,
192+ "Skipping checksum with unsupported hash algorithm"
193+ ) ;
194+ None
195+ } ;
196+
197+ let span = info_span ! ( "BatchDownload" , %url, path=%path. display( ) ) ;
198+ async move {
199+ let mut retries = 0 ;
200+ ' retry: loop {
201+ let mut req = self . inner . get_with_path ( url. path ( ) ) ?;
202+ match Self :: check_if_exists ( path, exp_size) . await ? {
203+ Header :: Skip => {
204+ return Ok ( ( ) ) ;
205+ }
206+ Header :: Range ( Some ( ( key, val) ) ) => {
207+ req = req. header ( key, val) ;
208+ }
209+ Header :: Range ( None ) => { }
210+ }
211+ let resp = req. send ( ) . await ?;
212+ let mut stream = check_http_error ( resp) . await ?. bytes_stream ( ) ;
213+ info ! ( "Downloading file" ) ;
214+ let mut output = BufWriter :: new (
215+ tokio:: fs:: OpenOptions :: new ( )
216+ . create ( true )
217+ . append ( true )
218+ . write ( true )
219+ . open ( path)
220+ . await ?,
221+ ) ;
222+ while let Some ( chunk) = stream. next ( ) . await {
223+ let chunk = match chunk {
224+ Ok ( chunk) => chunk,
225+ Err ( err) if retries < MAX_RETRIES => {
226+ retries += 1 ;
227+ error ! ( ?err, retries, "Retrying download" ) ;
228+ continue ' retry;
229+ }
230+ Err ( err) => {
231+ return Err ( crate :: Error :: from ( err) ) ;
232+ }
233+ } ;
234+ if retries > 0 {
235+ retries = 0 ;
236+ info ! ( "Resumed download" ) ;
237+ }
238+ if let Some ( hasher) = hasher. as_mut ( ) {
239+ hasher. update ( & chunk)
240+ }
241+ tokio:: io:: copy ( & mut chunk. as_ref ( ) , & mut output) . await ?;
242+ }
243+ debug ! ( "Completed download" ) ;
244+ Self :: verify_hash ( hasher, exp_hash_hex) . await ;
245+ return Ok ( ( ) ) ;
246+ }
247+ }
248+ . instrument ( span)
249+ . await
250+ }
251+
252+ async fn check_if_exists ( path : & Path , exp_size : u64 ) -> crate :: Result < Header > {
253+ let Ok ( metadata) = tokio:: fs:: metadata ( path) . await else {
254+ return Ok ( Header :: Range ( None ) ) ;
255+ } ;
256+ let actual_size = metadata. size ( ) ;
257+ match actual_size. cmp ( & exp_size) {
258+ Ordering :: Less => {
259+ debug ! (
260+ prev_downloaded_bytes = actual_size,
261+ total_bytes = exp_size,
262+ "Found existing file, resuming download"
263+ ) ;
264+ }
265+ Ordering :: Equal => {
266+ debug ! ( "Skipping download as file already exists and matches expected size" ) ;
267+ return Ok ( Header :: Skip ) ;
268+ }
269+ Ordering :: Greater => {
270+ return Err ( crate :: Error :: Io ( std:: io:: Error :: other ( format ! (
271+ "Batch file {} already exists with size {actual_size} which is larger than expected size {exp_size}" ,
272+ path. file_name( ) . unwrap( ) . display( ) ,
273+ ) ) ) ) ;
274+ }
275+ }
276+ Ok ( Header :: Range ( Some ( (
277+ "Range" ,
278+ format ! ( "bytes={}-" , metadata. size( ) ) ,
279+ ) ) ) )
280+ }
281+
282+ async fn verify_hash ( hasher : Option < Sha256 > , exp_hash_hex : & str ) {
283+ let Some ( hasher) = hasher else {
284+ return ;
285+ } ;
286+ let hash_hex = hasher. finalize ( ) . encode_hex :: < String > ( ) ;
287+ if hash_hex != exp_hash_hex {
288+ warn ! (
289+ hash_hex,
290+ exp_hash_hex, "Downloaded file failed checksum validation"
291+ ) ;
292+ } else {
293+ debug ! ( "Successfully verified checksum" ) ;
183294 }
184- Ok ( ( ) )
185295 }
186296
187297 const PATH_PREFIX : & ' static str = "batch" ;
@@ -403,7 +513,7 @@ pub struct DownloadParams {
403513 #[ builder( setter( transform = |dt: impl ToString | dt. to_string( ) ) ) ]
404514 pub job_id : String ,
405515 /// `None` means all files associated with the job will be downloaded.
406- #[ builder( default , setter( strip_option ) ) ]
516+ #[ builder( default , setter( transform = |filename : impl ToString | Some ( filename . to_string ( ) ) ) ) ]
407517 pub filename_to_download : Option < String > ,
408518}
409519
@@ -542,6 +652,11 @@ fn deserialize_compression<'de, D: serde::Deserializer<'de>>(
542652 Ok ( opt. unwrap_or ( Compression :: None ) )
543653}
544654
/// Outcome of checking for an existing file before a batch download.
enum Header {
    /// The file is already fully downloaded; no request is needed.
    Skip,
    /// Perform the download, optionally sending the contained HTTP header
    /// name/value pair (a `Range` header) to resume a partial file.
    Range(Option<(&'static str, String)>),
}
659+
545660#[ cfg( test) ]
546661mod tests {
547662 use reqwest:: StatusCode ;
0 commit comments