@@ -27,25 +27,50 @@ enum ProcessReceiptError {
2727 Both ( anyhow:: Error , anyhow:: Error ) ,
2828}
2929
30+ /// Indicates which versions of Receipts where processed
31+ /// It's intended to be used for migration tests
32+ #[ derive( Debug , PartialEq , Eq ) ]
33+ pub enum Processed {
34+ V1 ,
35+ V2 ,
36+ All ,
37+ None ,
38+ }
39+
3040impl InnerContext {
3141 async fn process_db_receipts (
3242 & self ,
3343 buffer : Vec < DatabaseReceipt > ,
34- ) -> Result < ( ) , ProcessReceiptError > {
44+ ) -> Result < Processed , ProcessReceiptError > {
3545 let ( v1_receipts, v2_receipts) : ( Vec < _ > , Vec < _ > ) =
3646 buffer. into_iter ( ) . partition_map ( |r| match r {
3747 DatabaseReceipt :: V1 ( db_receipt_v1) => Either :: Left ( db_receipt_v1) ,
3848 DatabaseReceipt :: V2 ( db_receipt_v2) => Either :: Right ( db_receipt_v2) ,
3949 } ) ;
40- let ( insert_v1, insert_v2) = tokio:: join!(
41- self . store_receipts_v1( v1_receipts) ,
42- self . store_receipts_v2( v2_receipts)
43- ) ;
50+
51+ let ( insert_v1, insert_v2) = match ( v1_receipts. is_empty ( ) , v2_receipts. is_empty ( ) ) {
52+ ( true , true ) => ( None , None ) ,
53+ ( false , true ) => ( Some ( self . store_receipts_v1 ( v1_receipts) . await ) , None ) ,
54+ ( true , false ) => ( None , Some ( self . store_receipts_v2 ( v2_receipts) . await ) ) ,
55+ ( false , false ) => {
56+ let ( v1, v2) = tokio:: join!(
57+ self . store_receipts_v1( v1_receipts) ,
58+ self . store_receipts_v2( v2_receipts) ,
59+ ) ;
60+ ( Some ( v1) , Some ( v2) )
61+ }
62+ } ;
63+
4464 match ( insert_v1, insert_v2) {
45- ( Err ( e1) , Err ( e2) ) => Err ( ProcessReceiptError :: Both ( e1. into ( ) , e2. into ( ) ) ) ,
46- ( Err ( e1) , _) => Err ( ProcessReceiptError :: V1 ( e1. into ( ) ) ) ,
47- ( _, Err ( e2) ) => Err ( ProcessReceiptError :: V2 ( e2. into ( ) ) ) ,
48- _ => Ok ( ( ) ) ,
65+ ( Some ( Err ( e1) ) , Some ( Err ( e2) ) ) => Err ( ProcessReceiptError :: Both ( e1. into ( ) , e2. into ( ) ) ) ,
66+ ( Some ( Err ( e1) ) , _) => Err ( ProcessReceiptError :: V1 ( e1. into ( ) ) ) ,
67+ ( _, Some ( Err ( e2) ) ) => Err ( ProcessReceiptError :: V2 ( e2. into ( ) ) ) ,
68+
69+ // only useful for testing
70+ ( Some ( Ok ( _) ) , None ) => Ok ( Processed :: V1 ) ,
71+ ( None , Some ( Ok ( _) ) ) => Ok ( Processed :: V2 ) ,
72+ ( Some ( Ok ( _) ) , Some ( Ok ( _) ) ) => Ok ( Processed :: All ) ,
73+ ( None , None ) => Ok ( Processed :: None ) ,
4974 }
5075 }
5176
@@ -305,3 +330,176 @@ impl DbReceiptV2 {
305330 } )
306331 }
307332}
333+
334+ #[ cfg( test) ]
335+ mod tests {
336+ use std:: { path:: PathBuf , sync:: LazyLock } ;
337+
338+ use futures:: future:: BoxFuture ;
339+ use sqlx:: {
340+ migrate:: { MigrationSource , Migrator } ,
341+ PgPool ,
342+ } ;
343+ use test_assets:: {
344+ create_signed_receipt, create_signed_receipt_v2, SignedReceiptRequest , INDEXER_ALLOCATIONS ,
345+ TAP_EIP712_DOMAIN ,
346+ } ;
347+
348+ use crate :: tap:: {
349+ receipt_store:: {
350+ DatabaseReceipt , DbReceiptV1 , DbReceiptV2 , InnerContext , ProcessReceiptError , Processed ,
351+ } ,
352+ AdapterError ,
353+ } ;
354+
355+ async fn create_v1 ( ) -> DatabaseReceipt {
356+ let alloc = INDEXER_ALLOCATIONS . values ( ) . next ( ) . unwrap ( ) . clone ( ) ;
357+ let v1 = create_signed_receipt (
358+ SignedReceiptRequest :: builder ( )
359+ . allocation_id ( alloc. id )
360+ . value ( 100 )
361+ . build ( ) ,
362+ )
363+ . await ;
364+ let v1 = DatabaseReceipt :: V1 ( DbReceiptV1 :: from_receipt ( & v1, & TAP_EIP712_DOMAIN ) . unwrap ( ) ) ;
365+ v1
366+ }
367+
368+ async fn create_v2 ( ) -> DatabaseReceipt {
369+ let v2 = create_signed_receipt_v2 ( ) . call ( ) . await ;
370+ let v2 = DatabaseReceipt :: V2 ( DbReceiptV2 :: from_receipt ( & v2, & TAP_EIP712_DOMAIN ) . unwrap ( ) ) ;
371+ v2
372+ }
373+
374+ mod when_all_migrations_are_run {
375+ use super :: * ;
376+
377+ #[ rstest:: rstest]
378+ #[ case( Processed :: None , async { vec![ ] } ) ]
379+ #[ case( Processed :: V1 , async { vec![ create_v1( ) . await ] } ) ]
380+ #[ case( Processed :: V2 , async { vec![ create_v2( ) . await ] } ) ]
381+ #[ case( Processed :: All , async { vec![ create_v2( ) . await , create_v1( ) . await ] } ) ]
382+ #[ sqlx:: test( migrations = "../../migrations" ) ]
383+ async fn v1_and_v2_are_processed_successfully (
384+ #[ ignore] pgpool : PgPool ,
385+ #[ case] expected : Processed ,
386+ #[ future( awt) ]
387+ #[ case]
388+ receipts : Vec < DatabaseReceipt > ,
389+ ) {
390+ let context = InnerContext { pgpool } ;
391+
392+ let res = context. process_db_receipts ( receipts) . await . unwrap ( ) ;
393+
394+ assert_eq ! ( res, expected) ;
395+ }
396+ }
397+
398+ mod when_horizon_migrations_are_ignored {
399+ use super :: * ;
400+
401+ #[ sqlx:: test( migrator = "WITHOUT_HORIZON_MIGRATIONS" ) ]
402+ async fn test_empty_receipts_are_processed_successfully ( pgpool : PgPool ) {
403+ let context = InnerContext { pgpool } ;
404+
405+ let res = context. process_db_receipts ( vec ! [ ] ) . await . unwrap ( ) ;
406+
407+ assert_eq ! ( res, Processed :: None ) ;
408+ }
409+
410+ #[ sqlx:: test( migrator = "WITHOUT_HORIZON_MIGRATIONS" ) ]
411+ async fn test_v1_receipts_are_processed_successfully ( pgpool : PgPool ) {
412+ let context = InnerContext { pgpool } ;
413+
414+ let v1 = create_v1 ( ) . await ;
415+ let receipts = vec ! [ v1] ;
416+
417+ let res = context. process_db_receipts ( receipts) . await . unwrap ( ) ;
418+
419+ assert_eq ! ( res, Processed :: V1 ) ;
420+ }
421+
422+ #[ rstest:: rstest]
423+ #[ case( async { vec![ create_v2( ) . await ] } ) ]
424+ #[ case( async { vec![ create_v2( ) . await , create_v1( ) . await ] } ) ]
425+ #[ sqlx:: test( migrator = "WITHOUT_HORIZON_MIGRATIONS" ) ]
426+ async fn test_cases_with_v2_receipts_fails_to_process (
427+ #[ ignore] pgpool : PgPool ,
428+ #[ future( awt) ]
429+ #[ case]
430+ receipts : Vec < DatabaseReceipt > ,
431+ ) {
432+ let context = InnerContext { pgpool } ;
433+
434+ let error = context. process_db_receipts ( receipts) . await . unwrap_err ( ) ;
435+
436+ let ProcessReceiptError :: V2 ( error) = error else {
437+ panic ! ( )
438+ } ;
439+ let d = error. downcast_ref :: < AdapterError > ( ) . unwrap ( ) . to_string ( ) ;
440+
441+ assert_eq ! (
442+ d,
443+ "error returned from database: relation \" tap_horizon_receipts\" does not exist"
444+ ) ;
445+ }
446+
447+ pub static WITHOUT_HORIZON_MIGRATIONS : LazyLock < Migrator > = LazyLock :: new ( create_migrator) ;
448+
449+ pub fn create_migrator ( ) -> Migrator {
450+ futures:: executor:: block_on ( Migrator :: new ( MigrationRunner :: new (
451+ "../../migrations" ,
452+ [ "horizon" ] ,
453+ ) ) )
454+ . unwrap ( )
455+ }
456+
457+ #[ derive( Debug ) ]
458+ pub struct MigrationRunner {
459+ migration_path : PathBuf ,
460+ ignored_migrations : Vec < String > ,
461+ }
462+
463+ impl MigrationRunner {
464+ /// Construct a new MigrationRunner that does not apply the given migrations.
465+ ///
466+ /// `ignored_migrations` is any iterable of strings that describes which
467+ /// migrations to be ignored.
468+ pub fn new < I > ( path : impl Into < PathBuf > , ignored_migrations : I ) -> Self
469+ where
470+ I : IntoIterator ,
471+ I :: Item : Into < String > ,
472+ {
473+ Self {
474+ migration_path : path. into ( ) ,
475+ ignored_migrations : ignored_migrations. into_iter ( ) . map ( Into :: into) . collect ( ) ,
476+ }
477+ }
478+ }
479+
480+ impl MigrationSource < ' static > for MigrationRunner {
481+ fn resolve (
482+ self ,
483+ ) -> BoxFuture < ' static , Result < Vec < sqlx:: migrate:: Migration > , sqlx:: error:: BoxDynError > >
484+ {
485+ Box :: pin ( async move {
486+ let canonical = self . migration_path . canonicalize ( ) ?;
487+ let migrations_with_paths =
488+ sqlx:: migrate:: resolve_blocking ( & canonical) . unwrap ( ) ;
489+
490+ let migrations_with_paths = migrations_with_paths
491+ . into_iter ( )
492+ . filter ( |( _, p) | {
493+ let path = p. to_str ( ) . unwrap ( ) ;
494+ self . ignored_migrations
495+ . iter ( )
496+ . any ( |ignored| !path. contains ( ignored) )
497+ } )
498+ . collect :: < Vec < _ > > ( ) ;
499+
500+ Ok ( migrations_with_paths. into_iter ( ) . map ( |( m, _p) | m) . collect ( ) )
501+ } )
502+ }
503+ }
504+ }
505+ }
0 commit comments