@@ -65,3 +65,135 @@ where
6565 }
6666 }
6767}
68+
69+ #[ cfg( test) ]
70+ mod tests {
71+ use std:: { sync:: Arc , time:: Duration } ;
72+
73+ use alloy:: primitives:: { address, Address } ;
74+ use axum:: {
75+ body:: Body ,
76+ http:: { Request , Response } ,
77+ } ;
78+ use prometheus:: core:: Collector ;
79+ use reqwest:: StatusCode ;
80+ use sqlx:: PgPool ;
81+ use tap_core:: {
82+ manager:: Manager ,
83+ receipt:: {
84+ checks:: { Check , CheckError , CheckList , CheckResult } ,
85+ state:: Checking ,
86+ ReceiptWithState ,
87+ } ,
88+ } ;
89+ use test_assets:: { create_signed_receipt, TAP_EIP712_DOMAIN } ;
90+ use tower:: { Service , ServiceBuilder , ServiceExt } ;
91+ use tower_http:: auth:: AsyncRequireAuthorizationLayer ;
92+
93+ use crate :: {
94+ middleware:: {
95+ auth:: tap_receipt_authorize,
96+ metrics:: { MetricLabelProvider , MetricLabels } ,
97+ } ,
98+ tap:: IndexerTapContext ,
99+ } ;
100+
101+ const ALLOCATION_ID : Address = address ! ( "deadbeefcafebabedeadbeefcafebabedeadbeef" ) ;
102+
103+ async fn handle ( _: Request < Body > ) -> anyhow:: Result < Response < Body > > {
104+ Ok ( Response :: new ( Body :: default ( ) ) )
105+ }
106+
107+ struct TestLabel ;
108+ impl MetricLabelProvider for TestLabel {
109+ fn get_labels ( & self ) -> Vec < & str > {
110+ vec ! [ "label1" ]
111+ }
112+ }
113+
114+ #[ sqlx:: test( migrations = "../../migrations" ) ]
115+ async fn test_tap_middleware ( pgpool : PgPool ) {
116+ let context = IndexerTapContext :: new ( pgpool. clone ( ) , TAP_EIP712_DOMAIN . clone ( ) ) . await ;
117+
118+ struct MyCheck ;
119+ #[ async_trait:: async_trait]
120+ impl Check for MyCheck {
121+ async fn check (
122+ & self ,
123+ _: & tap_core:: receipt:: Context ,
124+ receipt : & ReceiptWithState < Checking > ,
125+ ) -> CheckResult {
126+ if receipt. signed_receipt ( ) . message . nonce == 99 {
127+ Err ( CheckError :: Failed ( anyhow:: anyhow!( "Failed" ) ) )
128+ } else {
129+ Ok ( ( ) )
130+ }
131+ }
132+ }
133+
134+ let tap_manager = Box :: leak ( Box :: new ( Manager :: new (
135+ TAP_EIP712_DOMAIN . clone ( ) ,
136+ context,
137+ CheckList :: new ( vec ! [ Arc :: new( MyCheck ) ] ) ,
138+ ) ) ) ;
139+ let metric = Box :: leak ( Box :: new (
140+ prometheus:: register_counter_vec!(
141+ "test1" ,
142+ "Failed queries to handler" ,
143+ & [ "deployment" ]
144+ )
145+ . unwrap ( ) ,
146+ ) ) ;
147+
148+ let tap_auth = tap_receipt_authorize ( tap_manager, metric) ;
149+
150+ let authorization_middleware = AsyncRequireAuthorizationLayer :: new ( tap_auth) ;
151+
152+ let mut service = ServiceBuilder :: new ( )
153+ . layer ( authorization_middleware)
154+ . service_fn ( handle) ;
155+
156+ let handle = service. ready ( ) . await . unwrap ( ) ;
157+
158+ let receipt = create_signed_receipt ( ALLOCATION_ID , 1 , 1 , 1 ) . await ;
159+
160+ // check with receipt
161+ let mut req = Request :: new ( Default :: default ( ) ) ;
162+ req. extensions_mut ( ) . insert ( receipt) ;
163+ let res = handle. call ( req) . await . unwrap ( ) ;
164+ assert_eq ! ( res. status( ) , StatusCode :: OK ) ;
165+
166+ // todo make this sleep better
167+ tokio:: time:: sleep ( Duration :: from_millis ( 100 ) ) . await ;
168+
169+ // verify receipts
170+ let result = sqlx:: query!( "SELECT * FROM scalar_tap_receipts" )
171+ . fetch_all ( & pgpool)
172+ . await
173+ . unwrap ( ) ;
174+ assert_eq ! ( result. len( ) , 1 ) ;
175+ // if it fails tap receipt, should return failed to process payment + tap message
176+
177+ assert_eq ! ( metric. collect( ) . first( ) . unwrap( ) . get_metric( ) . len( ) , 0 ) ;
178+
179+ // default labels, all empty
180+ let labels: MetricLabels = Arc :: new ( TestLabel ) ;
181+
182+ let mut receipt = create_signed_receipt ( ALLOCATION_ID , 1 , 1 , 1 ) . await ;
183+ // change the nonce to make the receipt invalid
184+ receipt. message . nonce = 99 ;
185+ let mut req = Request :: new ( Default :: default ( ) ) ;
186+ req. extensions_mut ( ) . insert ( receipt) ;
187+ req. extensions_mut ( ) . insert ( labels) ;
188+ let res = handle. call ( req) . await . unwrap ( ) ;
189+ assert_eq ! ( res. status( ) , StatusCode :: BAD_REQUEST ) ;
190+
191+ assert_eq ! ( metric. collect( ) . first( ) . unwrap( ) . get_metric( ) . len( ) , 1 ) ;
192+
193+ // if it doesnt contain the signed receipt
194+ // should return payment required
195+ let req = Request :: new ( Default :: default ( ) ) ;
196+ let res = handle. call ( req) . await . unwrap ( ) ;
197+ assert_eq ! ( res. status( ) , StatusCode :: PAYMENT_REQUIRED ) ;
198+ }
199+ }
0 commit comments