@@ -7,13 +7,16 @@ use anyhow::Result;
77use bigdecimal:: BigDecimal ;
88use sqlx:: { PgPool , Row } ;
99
10+ use crate :: test_config:: TestConfig ;
11+
1012/// Unified database checker for both V1 and V2 TAP tables
1113pub struct DatabaseChecker {
1214 pool : PgPool ,
15+ cfg : TestConfig ,
1316}
1417
1518/// TAP version enum to specify which tables to query
16- #[ derive( Debug , Clone , Copy ) ]
19+ #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
1720pub enum TapVersion {
1821 V1 , // Legacy receipt aggregator tables
1922 V2 , // Horizon tables
@@ -106,9 +109,9 @@ pub struct RecentReceipt {
106109
107110impl DatabaseChecker {
108111 /// Create new DatabaseChecker with database connection
109- pub async fn new ( database_url : & str ) -> Result < Self > {
110- let pool = PgPool :: connect ( database_url) . await ?;
111- Ok ( Self { pool } )
112+ pub async fn new ( cfg : TestConfig ) -> Result < Self > {
113+ let pool = PgPool :: connect ( cfg . database_url ( ) ) . await ?;
114+ Ok ( Self { pool, cfg } )
112115 }
113116
114117 /// Get combined V1 and V2 state for comprehensive testing
@@ -644,7 +647,7 @@ impl DatabaseChecker {
644647 TapVersion :: V1 => sqlx:: query_scalar (
645648 r#"
646649 SELECT COUNT(*)
647- FROM tap_ravs
650+ FROM scalar_tap_ravs
648651 WHERE allocation_id = $1
649652 AND LOWER(sender_address) = $2
650653 "# ,
@@ -661,7 +664,7 @@ impl DatabaseChecker {
661664
662665 /// Get the total value of receipts for an identifier that don't have a RAV yet
663666 pub async fn get_pending_receipt_value (
664- & self ,
667+ & mut self ,
665668 identifier : & str , // collection_id for V2, allocation_id for V1
666669 payer : & str ,
667670 version : TapVersion ,
@@ -670,36 +673,47 @@ impl DatabaseChecker {
670673
671674 let pending_value: Option < BigDecimal > = match version {
672675 TapVersion :: V2 => {
676+ // Sum receipts for this collection/payer that are newer than the last RAV
677+ // and older than the timestamp buffer cutoff (eligible to aggregate)
678+ let buffer_secs = self . get_timestamp_buffer_secs ( ) ?;
679+ let current_time_ns = std:: time:: SystemTime :: now ( )
680+ . duration_since ( std:: time:: UNIX_EPOCH )
681+ . unwrap ( )
682+ . as_nanos ( ) as u64 ;
683+ let cutoff_ns = current_time_ns - buffer_secs * 1_000_000_000 ;
684+
673685 sqlx:: query_scalar (
674686 r#"
675- SELECT SUM(r.value)
676- FROM tap_horizon_receipts r
677- LEFT JOIN tap_horizon_ravs rav ON (
678- r.collection_id = rav.collection_id
679- AND LOWER(r.payer) = LOWER(rav.payer)
680- AND LOWER(r.service_provider) = LOWER(rav.service_provider)
681- AND LOWER(r.data_service) = LOWER(rav.data_service)
687+ WITH last_rav AS (
688+ SELECT COALESCE(MAX(timestamp_ns), 0) AS last_ts
689+ FROM tap_horizon_ravs rav
690+ WHERE rav.collection_id = $1
691+ AND LOWER(rav.payer) = $2
682692 )
683- WHERE r.collection_id = $1
684- AND LOWER(r.payer) = $2
685- AND rav.collection_id IS NULL
693+ SELECT COALESCE(SUM(r.value), 0)
694+ FROM tap_horizon_receipts r, last_rav lr
695+ WHERE r.collection_id = $1
696+ AND LOWER(r.payer) = $2
697+ AND r.timestamp_ns > lr.last_ts
698+ AND r.timestamp_ns <= $3
686699 "# ,
687700 )
688701 . bind ( identifier)
689702 . bind ( & normalized_payer)
703+ . bind ( cutoff_ns as i64 )
690704 . fetch_one ( & self . pool )
691705 . await ?
692706 }
693707 TapVersion :: V1 => sqlx:: query_scalar (
694708 r#"
695709 SELECT SUM(r.value)
696- FROM tap_receipts r
697- LEFT JOIN tap_ravs rav ON (
710+ FROM scalar_tap_receipts r
711+ LEFT JOIN scalar_tap_ravs rav ON (
698712 r.allocation_id = rav.allocation_id
699- AND LOWER(r.sender_address ) = LOWER(rav.sender_address)
713+ AND LOWER(r.signer_address ) = LOWER(rav.sender_address)
700714 )
701715 WHERE r.allocation_id = $1
702- AND LOWER(r.sender_address ) = $2
716+ AND LOWER(r.signer_address ) = $2
703717 AND rav.allocation_id IS NULL
704718 "# ,
705719 )
@@ -714,6 +728,7 @@ impl DatabaseChecker {
714728 }
715729
716730 /// Wait for a RAV to be created with timeout
731+ /// V1 only
717732 pub async fn wait_for_rav_creation (
718733 & self ,
719734 payer : & str ,
@@ -722,6 +737,9 @@ impl DatabaseChecker {
722737 check_interval_seconds : u64 ,
723738 version : TapVersion ,
724739 ) -> Result < bool > {
740+ if TapVersion :: V2 == version {
741+ anyhow:: bail!( "wait_for_rav_creation is only supported for V1 TAP" ) ;
742+ }
725743 let start_time = std:: time:: Instant :: now ( ) ;
726744 let timeout_duration = std:: time:: Duration :: from_secs ( timeout_seconds) ;
727745
@@ -919,7 +937,7 @@ impl DatabaseChecker {
919937
920938 /// Diagnostic function to analyze timestamp buffer issues during RAV generation
921939 /// This simulates the exact logic used in tap_core's Manager::collect_receipts
922- pub async fn diagnose_timestamp_buffer (
940+ async fn diagnose_timestamp_buffer_impl (
923941 & self ,
924942 payer : & str ,
925943 identifier : & str , // collection_id for V2, allocation_id for V1
@@ -1107,4 +1125,211 @@ impl DatabaseChecker {
11071125
11081126 Ok ( ( ) )
11091127 }
1128+
1129+ /// Get the trigger value (wei) from tap-agent configuration
1130+ pub fn get_trigger_value_wei ( & mut self ) -> Result < u128 > {
1131+ self . cfg . get_tap_trigger_value_wei ( )
1132+ }
1133+
1134+ /// Get the timestamp buffer seconds from tap-agent configuration
1135+ pub fn get_timestamp_buffer_secs ( & mut self ) -> Result < u64 > {
1136+ self . cfg . get_tap_timestamp_buffer_secs ( )
1137+ }
1138+
1139+ /// Diagnostic function that uses tap-agent configuration
1140+ pub async fn diagnose_timestamp_buffer (
1141+ & mut self ,
1142+ payer : & str ,
1143+ identifier : & str , // collection_id for V2, allocation_id for V1
1144+ version : TapVersion ,
1145+ ) -> Result < ( ) > {
1146+ let buffer_seconds = self . get_timestamp_buffer_secs ( ) ?;
1147+ self . diagnose_timestamp_buffer_impl ( payer, identifier, buffer_seconds, version)
1148+ . await
1149+ }
1150+
1151+ /// Print summary with tap-agent configuration context
1152+ pub async fn print_summary ( & mut self , payer : & str , version : TapVersion ) -> Result < ( ) > {
1153+ let state = self . get_state ( payer, version) . await ?;
1154+ let trigger_value = self . get_trigger_value_wei ( ) ?;
1155+ let buffer_secs = self . get_timestamp_buffer_secs ( ) ?;
1156+ let max_willing_to_lose = self . cfg . get_tap_max_amount_willing_to_lose_grt ( ) ?;
1157+ let trigger_divisor = self . cfg . get_tap_trigger_value_divisor ( ) ?;
1158+
1159+ let version_name = match version {
1160+ TapVersion :: V1 => "V1 (Legacy)" ,
1161+ TapVersion :: V2 => "V2 (Horizon)" ,
1162+ } ;
1163+
1164+ println ! ( "\n === {} TAP Database State (Config) ===" , version_name) ;
1165+ println ! ( "Payer: {}" , payer) ;
1166+
1167+ // Show tap-agent configuration values
1168+ println ! ( "🔧 Tap-Agent Configuration:" ) ;
1169+ println ! (
1170+ " Max Amount Willing to Lose: {:.6} GRT" ,
1171+ max_willing_to_lose
1172+ ) ;
1173+ println ! ( " Trigger Value Divisor: {}" , trigger_divisor) ;
1174+ println ! (
1175+ " → Calculated Trigger Value: {} wei ({:.6} GRT)" ,
1176+ trigger_value,
1177+ trigger_value as f64 / 1e18
1178+ ) ;
1179+ println ! (
1180+ " → Formula: {:.6} GRT / {} = {:.6} GRT" ,
1181+ max_willing_to_lose,
1182+ trigger_divisor,
1183+ trigger_value as f64 / 1e18
1184+ ) ;
1185+ println ! ( " Timestamp Buffer: {} seconds" , buffer_secs) ;
1186+
1187+ println ! ( "📊 Database Statistics:" ) ;
1188+ println ! (
1189+ " Receipts: {} (total value: {} wei)" ,
1190+ state. receipt_count, state. receipt_value
1191+ ) ;
1192+ println ! (
1193+ " RAVs: {} (total value: {} wei)" ,
1194+ state. rav_count, state. rav_value
1195+ ) ;
1196+ println ! ( " Pending RAV Collections: {}" , state. pending_rav_count) ;
1197+ println ! ( " Failed RAV Requests: {}" , state. failed_rav_count) ;
1198+ println ! ( " Invalid Receipts: {}" , state. invalid_receipt_count) ;
1199+
1200+ // Calculate trigger progress
1201+ if state. pending_rav_count > 0 {
1202+ // Get total pending value across all collections
1203+ let total_pending_value = self . get_total_pending_value ( payer, version) . await ?;
1204+ let progress_percentage = ( total_pending_value. clone ( ) * BigDecimal :: from ( 100 ) )
1205+ / BigDecimal :: from ( trigger_value) ;
1206+
1207+ println ! ( "\n 📈 Trigger Analysis:" ) ;
1208+ println ! (
1209+ " Total Pending Value: {} wei ({:.6} GRT)" ,
1210+ total_pending_value,
1211+ total_pending_value
1212+ . to_string( )
1213+ . parse:: <f64 >( )
1214+ . unwrap_or( 0.0 )
1215+ / 1e18
1216+ ) ;
1217+ println ! (
1218+ " Progress to Trigger: {:.1}%" ,
1219+ progress_percentage
1220+ . to_string( )
1221+ . parse:: <f64 >( )
1222+ . unwrap_or( 0.0 )
1223+ ) ;
1224+
1225+ if total_pending_value >= BigDecimal :: from ( trigger_value) {
1226+ println ! ( " ✅ Ready to trigger RAV!" ) ;
1227+ } else {
1228+ let needed = BigDecimal :: from ( trigger_value) - & total_pending_value;
1229+ println ! (
1230+ " ⏳ Need {} wei more ({:.6} GRT)" ,
1231+ needed,
1232+ needed. to_string( ) . parse:: <f64 >( ) . unwrap_or( 0.0 ) / 1e18
1233+ ) ;
1234+ }
1235+ }
1236+
1237+ Ok ( ( ) )
1238+ }
1239+
1240+ /// Get total pending value across all collections for a payer
1241+ async fn get_total_pending_value (
1242+ & self ,
1243+ payer : & str ,
1244+ version : TapVersion ,
1245+ ) -> Result < BigDecimal > {
1246+ let normalized_payer = payer. trim_start_matches ( "0x" ) . to_lowercase ( ) ;
1247+
1248+ let total_pending: Option < BigDecimal > = match version {
1249+ TapVersion :: V2 => {
1250+ sqlx:: query_scalar (
1251+ r#"
1252+ SELECT SUM(r.value)
1253+ FROM tap_horizon_receipts r
1254+ LEFT JOIN tap_horizon_ravs rav ON (
1255+ r.collection_id = rav.collection_id
1256+ AND LOWER(r.payer) = LOWER(rav.payer)
1257+ AND LOWER(r.service_provider) = LOWER(rav.service_provider)
1258+ AND LOWER(r.data_service) = LOWER(rav.data_service)
1259+ )
1260+ WHERE LOWER(r.payer) = $1 AND rav.collection_id IS NULL
1261+ "# ,
1262+ )
1263+ . bind ( & normalized_payer)
1264+ . fetch_one ( & self . pool )
1265+ . await ?
1266+ }
1267+ TapVersion :: V1 => sqlx:: query_scalar (
1268+ r#"
1269+ SELECT SUM(r.value)
1270+ FROM scalar_tap_receipts r
1271+ LEFT JOIN scalar_tap_ravs rav ON (
1272+ r.allocation_id = rav.allocation_id
1273+ AND LOWER(r.signer_address) = LOWER(rav.sender_address)
1274+ )
1275+ WHERE LOWER(r.signer_address) = $1 AND rav.allocation_id IS NULL
1276+ "# ,
1277+ )
1278+ . bind ( & normalized_payer)
1279+ . fetch_optional ( & self . pool )
1280+ . await ?
1281+ . flatten ( ) ,
1282+ } ;
1283+
1284+ Ok ( total_pending. unwrap_or_else ( || BigDecimal :: from_str ( "0" ) . unwrap ( ) ) )
1285+ }
1286+
1287+ /// Check for V2 RAV generation by looking at both count and value changes
1288+ /// Returns (rav_was_created, rav_value_increased)
1289+ pub async fn check_v2_rav_progress (
1290+ & self ,
1291+ payer : & str ,
1292+ initial_rav_count : i64 ,
1293+ initial_rav_value : & BigDecimal ,
1294+ version : TapVersion ,
1295+ ) -> Result < ( bool , bool ) > {
1296+ let current_state = self . get_state ( payer, version) . await ?;
1297+
1298+ // Check if new RAV was created (0 → 1)
1299+ let rav_was_created = current_state. rav_count > initial_rav_count;
1300+
1301+ // Check if existing RAV value increased (for V2 updates)
1302+ let rav_value_increased = & current_state. rav_value > initial_rav_value;
1303+
1304+ Ok ( ( rav_was_created, rav_value_increased) )
1305+ }
1306+
1307+ /// Enhanced wait for RAV creation that handles V1
1308+ pub async fn wait_for_rav_creation_or_update (
1309+ & self ,
1310+ payer : & str ,
1311+ initial_rav_count : i64 ,
1312+ initial_rav_value : BigDecimal ,
1313+ timeout_seconds : u64 ,
1314+ check_interval_seconds : u64 ,
1315+ version : TapVersion ,
1316+ ) -> Result < ( bool , bool ) > {
1317+ let start_time = std:: time:: Instant :: now ( ) ;
1318+ let timeout_duration = std:: time:: Duration :: from_secs ( timeout_seconds) ;
1319+
1320+ while start_time. elapsed ( ) < timeout_duration {
1321+ let ( rav_created, rav_increased) = self
1322+ . check_v2_rav_progress ( payer, initial_rav_count, & initial_rav_value, version)
1323+ . await ?;
1324+
1325+ // Success for either new RAV creation or value increase
1326+ if rav_created || rav_increased {
1327+ return Ok ( ( rav_created, rav_increased) ) ;
1328+ }
1329+
1330+ tokio:: time:: sleep ( std:: time:: Duration :: from_secs ( check_interval_seconds) ) . await ;
1331+ }
1332+
1333+ Ok ( ( false , false ) )
1334+ }
11101335}
0 commit comments