@@ -7,7 +7,7 @@ use futures::{Stream, TryStreamExt};
77use reqwest:: { header:: ACCEPT , RequestBuilder } ;
88use tokio:: {
99 fs:: File ,
10- io:: { AsyncReadExt , AsyncWriteExt , BufWriter } ,
10+ io:: { AsyncReadExt , AsyncWriteExt , BufReader , BufWriter } ,
1111} ;
1212use tokio_util:: { bytes:: Bytes , io:: StreamReader } ;
1313use typed_builder:: TypedBuilder ;
@@ -53,9 +53,10 @@ impl TimeseriesClient<'_> {
5353 params. limit ,
5454 )
5555 . await ?;
56- let mut decoder: AsyncDbnDecoder < _ > = AsyncDbnDecoder :: with_zstd_buffer ( reader) . await ?;
57- decoder. set_upgrade_policy ( params. upgrade_policy ) ?;
58- Ok ( decoder)
56+ Ok (
57+ AsyncDbnDecoder :: with_upgrade_policy ( zstd_decoder ( reader) , params. upgrade_policy )
58+ . await ?,
59+ )
5960 }
6061
6162 /// Makes a streaming request for timeseries data from Databento.
@@ -86,15 +87,21 @@ impl TimeseriesClient<'_> {
8687 params. limit ,
8788 )
8889 . await ?;
89- let mut http_decoder = AsyncDbnDecoder :: with_zstd_buffer ( reader) . await ?;
90- http_decoder. set_upgrade_policy ( params. upgrade_policy ) ?;
90+ let mut http_decoder =
91+ AsyncDbnDecoder :: with_upgrade_policy ( zstd_decoder ( reader) , params. upgrade_policy )
92+ . await ?;
9193 let file = BufWriter :: new ( File :: create ( & params. path ) . await ?) ;
9294 let mut encoder = AsyncDbnEncoder :: with_zstd ( file, http_decoder. metadata ( ) ) . await ?;
9395 while let Some ( rec_ref) = http_decoder. decode_record_ref ( ) . await ? {
9496 encoder. encode_record_ref ( rec_ref) . await ?;
9597 }
9698 encoder. get_mut ( ) . shutdown ( ) . await ?;
97- Ok ( AsyncDbnDecoder :: from_zstd_file ( & params. path ) . await ?)
99+ Ok ( AsyncDbnDecoder :: with_upgrade_policy (
100+ zstd_decoder ( BufReader :: new ( File :: open ( & params. path ) . await ?) ) ,
101+ // Applied upgrade policy during initial decoding
102+ VersionUpgradePolicy :: AsIs ,
103+ )
104+ . await ?)
98105 }
99106
100107 #[ allow( clippy:: too_many_arguments) ] // private method
@@ -241,10 +248,21 @@ impl GetRangeParams {
241248 }
242249}
243250
251+ fn zstd_decoder < R > ( reader : R ) -> async_compression:: tokio:: bufread:: ZstdDecoder < R >
252+ where
253+ R : tokio:: io:: AsyncBufReadExt + Unpin ,
254+ {
255+ let mut zstd_decoder = async_compression:: tokio:: bufread:: ZstdDecoder :: new ( reader) ;
256+ // explicitly enable decoding multiple frames
257+ zstd_decoder. multiple_members ( true ) ;
258+ zstd_decoder
259+ }
260+
244261#[ cfg( test) ]
245262mod tests {
246263 use dbn:: { record:: TradeMsg , Dataset } ;
247264 use reqwest:: StatusCode ;
265+ use rstest:: * ;
248266 use time:: macros:: datetime;
249267 use wiremock:: {
250268 matchers:: { basic_auth, method, path} ,
@@ -260,8 +278,12 @@ mod tests {
260278
261279 const API_KEY : & str = "test-API" ;
262280
281+ #[ rstest]
282+ #[ case( VersionUpgradePolicy :: AsIs , 1 ) ]
283+ #[ case( VersionUpgradePolicy :: UpgradeToV2 , 2 ) ]
284+ #[ case( VersionUpgradePolicy :: UpgradeToV3 , 3 ) ]
263285 #[ tokio:: test]
264- async fn test_get_range ( ) {
286+ async fn test_get_range ( # [ case ] upgrade_policy : VersionUpgradePolicy , # [ case ] exp_version : u8 ) {
265287 const START : time:: OffsetDateTime = datetime ! ( 2023 - 06 - 14 00 : 00 UTC ) ;
266288 const END : time:: OffsetDateTime = datetime ! ( 2023 - 06 - 17 00 : 00 UTC ) ;
267289 const SCHEMA : Schema = Schema :: Trades ;
@@ -299,19 +321,29 @@ mod tests {
299321 . schema ( SCHEMA )
300322 . symbols ( vec ! [ "SPOT" , "AAPL" ] )
301323 . date_time_range ( ( START , END ) )
324+ . upgrade_policy ( upgrade_policy)
302325 . build ( ) ,
303326 )
304327 . await
305328 . unwrap ( ) ;
306- assert_eq ! ( decoder. metadata( ) . schema. unwrap( ) , SCHEMA ) ;
329+ let metadata = decoder. metadata ( ) ;
330+ assert_eq ! ( metadata. schema. unwrap( ) , SCHEMA ) ;
331+ assert_eq ! ( metadata. version, exp_version) ;
307332 // Two records
308333 decoder. decode_record :: < TradeMsg > ( ) . await . unwrap ( ) . unwrap ( ) ;
309334 decoder. decode_record :: < TradeMsg > ( ) . await . unwrap ( ) . unwrap ( ) ;
310335 assert ! ( decoder. decode_record:: <TradeMsg >( ) . await . unwrap( ) . is_none( ) ) ;
311336 }
312337
338+ #[ rstest]
339+ #[ case( VersionUpgradePolicy :: AsIs , 1 ) ]
340+ #[ case( VersionUpgradePolicy :: UpgradeToV2 , 2 ) ]
341+ #[ case( VersionUpgradePolicy :: UpgradeToV3 , 3 ) ]
313342 #[ tokio:: test]
314- async fn test_get_range_to_file ( ) {
343+ async fn test_get_range_to_file (
344+ #[ case] upgrade_policy : VersionUpgradePolicy ,
345+ #[ case] exp_version : u8 ,
346+ ) {
315347 const START : time:: OffsetDateTime = datetime ! ( 2024 - 05 - 17 00 : 00 UTC ) ;
316348 const END : time:: OffsetDateTime = datetime ! ( 2024 - 05 - 18 00 : 00 UTC ) ;
317349 const SCHEMA : Schema = Schema :: Trades ;
@@ -354,11 +386,14 @@ mod tests {
354386 . stype_in ( SType :: Parent )
355387 . date_time_range ( ( START , END ) )
356388 . path ( path. clone ( ) )
389+ . upgrade_policy ( upgrade_policy)
357390 . build ( ) ,
358391 )
359392 . await
360393 . unwrap ( ) ;
361- assert_eq ! ( decoder. metadata( ) . schema. unwrap( ) , SCHEMA ) ;
394+ let metadata = decoder. metadata ( ) ;
395+ assert_eq ! ( metadata. schema. unwrap( ) , SCHEMA ) ;
396+ assert_eq ! ( metadata. version, exp_version) ;
362397 // Two records
363398 decoder. decode_record :: < TradeMsg > ( ) . await . unwrap ( ) . unwrap ( ) ;
364399 decoder. decode_record :: < TradeMsg > ( ) . await . unwrap ( ) . unwrap ( ) ;
0 commit comments