@@ -16,9 +16,14 @@ use crate::{
1616 tap:: { CheckingReceipt , TapReceipt } ,
1717} ;
1818
19+ enum DenyListVersion {
20+ V1 ,
21+ V2 ,
22+ }
23+
1924pub struct DenyListCheck {
20- sender_denylist : Arc < RwLock < HashSet < Address > > > ,
21- _sender_denylist_watcher_handle : Arc < tokio :: task :: JoinHandle < ( ) > > ,
25+ sender_denylist_v1 : Arc < RwLock < HashSet < Address > > > ,
26+ sender_denylist_v2 : Arc < RwLock < HashSet < Address > > > ,
2227 sender_denylist_watcher_cancel_token : tokio_util:: sync:: CancellationToken ,
2328
2429 #[ cfg( test) ]
@@ -29,43 +34,65 @@ impl DenyListCheck {
2934 pub async fn new ( pgpool : PgPool ) -> Self {
3035 // Listen to pg_notify events. We start it before updating the sender_denylist so that we
3136 // don't miss any updates. PG will buffer the notifications until we start consuming them.
32- let mut pglistener = PgListener :: connect_with ( & pgpool. clone ( ) ) . await . unwrap ( ) ;
33- pglistener
37+ let mut pglistener_v1 = PgListener :: connect_with ( & pgpool. clone ( ) ) . await . unwrap ( ) ;
38+ let mut pglistener_v2 = PgListener :: connect_with ( & pgpool. clone ( ) ) . await . unwrap ( ) ;
39+ pglistener_v1
3440 . listen ( "scalar_tap_deny_notification" )
3541 . await
3642 . expect (
3743 "should be able to subscribe to Postgres Notify events on the channel \
3844 'scalar_tap_deny_notification'",
3945 ) ;
4046
47+ pglistener_v2
48+ . listen ( "tap_horizon_deny_notification" )
49+ . await
50+ . expect (
51+ "should be able to subscribe to Postgres Notify events on the channel \
52+ 'tap_horizon_deny_notification'",
53+ ) ;
54+
4155 // Fetch the denylist from the DB
42- let sender_denylist = Arc :: new ( RwLock :: new ( HashSet :: new ( ) ) ) ;
43- Self :: sender_denylist_reload ( pgpool. clone ( ) , sender_denylist. clone ( ) )
56+ let sender_denylist_v1 = Arc :: new ( RwLock :: new ( HashSet :: new ( ) ) ) ;
57+ let sender_denylist_v2 = Arc :: new ( RwLock :: new ( HashSet :: new ( ) ) ) ;
58+ Self :: sender_denylist_reload_v1 ( pgpool. clone ( ) , sender_denylist_v1. clone ( ) )
4459 . await
4560 . expect ( "should be able to fetch the sender_denylist from the DB on startup" ) ;
4661
4762 #[ cfg( test) ]
4863 let notify = std:: sync:: Arc :: new ( tokio:: sync:: Notify :: new ( ) ) ;
4964
5065 let sender_denylist_watcher_cancel_token = tokio_util:: sync:: CancellationToken :: new ( ) ;
51- let sender_denylist_watcher_handle = Arc :: new ( tokio:: spawn ( Self :: sender_denylist_watcher (
66+ tokio:: spawn ( Self :: sender_denylist_watcher (
5267 pgpool. clone ( ) ,
53- pglistener ,
54- sender_denylist . clone ( ) ,
68+ pglistener_v1 ,
69+ sender_denylist_v1 . clone ( ) ,
5570 sender_denylist_watcher_cancel_token. clone ( ) ,
71+ DenyListVersion :: V1 ,
5672 #[ cfg( test) ]
5773 notify. clone ( ) ,
58- ) ) ) ;
74+ ) ) ;
75+
76+ tokio:: spawn ( Self :: sender_denylist_watcher (
77+ pgpool. clone ( ) ,
78+ pglistener_v2,
79+ sender_denylist_v1. clone ( ) ,
80+ sender_denylist_watcher_cancel_token. clone ( ) ,
81+ DenyListVersion :: V2 ,
82+ #[ cfg( test) ]
83+ notify. clone ( ) ,
84+ ) ) ;
85+
5986 Self {
60- sender_denylist ,
61- _sender_denylist_watcher_handle : sender_denylist_watcher_handle ,
87+ sender_denylist_v1 ,
88+ sender_denylist_v2 ,
6289 sender_denylist_watcher_cancel_token,
6390 #[ cfg( test) ]
6491 notify,
6592 }
6693 }
6794
68- async fn sender_denylist_reload (
95+ async fn sender_denylist_reload_v1 (
6996 pgpool : PgPool ,
7097 denylist_rwlock : Arc < RwLock < HashSet < Address > > > ,
7198 ) -> anyhow:: Result < ( ) > {
@@ -86,11 +113,33 @@ impl DenyListCheck {
86113 Ok ( ( ) )
87114 }
88115
116+ async fn sender_denylist_reload_v2 (
117+ pgpool : PgPool ,
118+ denylist_rwlock : Arc < RwLock < HashSet < Address > > > ,
119+ ) -> anyhow:: Result < ( ) > {
120+ // Fetch the denylist from the DB
121+ let sender_denylist = sqlx:: query!(
122+ r#"
123+ SELECT sender_address FROM tap_horizon_denylist
124+ "#
125+ )
126+ . fetch_all ( & pgpool)
127+ . await ?
128+ . iter ( )
129+ . map ( |row| Address :: from_str ( & row. sender_address ) )
130+ . collect :: < Result < HashSet < _ > , _ > > ( ) ?;
131+
132+ * ( denylist_rwlock. write ( ) . unwrap ( ) ) = sender_denylist;
133+
134+ Ok ( ( ) )
135+ }
136+
89137 async fn sender_denylist_watcher (
90138 pgpool : PgPool ,
91139 mut pglistener : PgListener ,
92140 denylist : Arc < RwLock < HashSet < Address > > > ,
93141 cancel_token : tokio_util:: sync:: CancellationToken ,
142+ version : DenyListVersion ,
94143 #[ cfg( test) ] notify : std:: sync:: Arc < tokio:: sync:: Notify > ,
95144 ) {
96145 #[ derive( serde:: Deserialize ) ]
@@ -137,10 +186,14 @@ impl DenyListCheck {
137186 denylist.",
138187 denylist_notification. tg_op
139188 ) ;
140-
141- Self :: sender_denylist_reload( pgpool. clone( ) , denylist. clone( ) )
142- . await
143- . expect( "should be able to reload the sender denylist" )
189+ match version {
190+ DenyListVersion :: V1 => Self :: sender_denylist_reload_v1( pgpool. clone( ) , denylist. clone( ) )
191+ . await
192+ . expect( "should be able to reload the sender denylist" ) ,
193+ DenyListVersion :: V2 => Self :: sender_denylist_reload_v2( pgpool. clone( ) , denylist. clone( ) )
194+ . await
195+ . expect( "should be able to reload the sender denylist" ) ,
196+ }
144197 }
145198 }
146199 #[ cfg( test) ]
@@ -153,18 +206,30 @@ impl DenyListCheck {
153206
154207#[ async_trait:: async_trait]
155208impl Check < TapReceipt > for DenyListCheck {
156- async fn check ( & self , ctx : & tap_core:: receipt:: Context , _: & CheckingReceipt ) -> CheckResult {
209+ async fn check (
210+ & self ,
211+ ctx : & tap_core:: receipt:: Context ,
212+ receipt : & CheckingReceipt ,
213+ ) -> CheckResult {
157214 let Sender ( receipt_sender) = ctx
158215 . get :: < Sender > ( )
159216 . ok_or ( CheckError :: Failed ( anyhow:: anyhow!( "Could not find sender" ) ) ) ?;
160217
218+ let denied = match receipt. signed_receipt ( ) {
219+ TapReceipt :: V1 ( _) => self
220+ . sender_denylist_v1
221+ . read ( )
222+ . unwrap ( )
223+ . contains ( receipt_sender) ,
224+ TapReceipt :: V2 ( _) => self
225+ . sender_denylist_v2
226+ . read ( )
227+ . unwrap ( )
228+ . contains ( receipt_sender) ,
229+ } ;
230+
161231 // Check that the sender is not denylisted
162- if self
163- . sender_denylist
164- . read ( )
165- . unwrap ( )
166- . contains ( receipt_sender)
167- {
232+ if denied {
168233 return Err ( CheckError :: Failed ( anyhow:: anyhow!(
169234 "Received a receipt from a denylisted sender: {}" ,
170235 receipt_sender
0 commit comments