@@ -8,3 +8,99 @@ mod tap;
88pub use bearer:: Bearer ;
99pub use or:: OrExt ;
1010pub use tap:: tap_receipt_authorize;
11+
12+ #[ cfg( test) ]
13+ mod tests {
14+ use std:: time:: Duration ;
15+
16+ use alloy:: primitives:: { address, Address } ;
17+ use axum:: body:: Body ;
18+ use axum:: http:: { Request , Response } ;
19+ use reqwest:: { header, StatusCode } ;
20+ use sqlx:: PgPool ;
21+ use tap_core:: { manager:: Manager , receipt:: checks:: CheckList } ;
22+ use tower:: { Service , ServiceBuilder , ServiceExt } ;
23+ use tower_http:: auth:: AsyncRequireAuthorizationLayer ;
24+
25+ use crate :: middleware:: auth:: { self , Bearer , OrExt } ;
26+ use crate :: tap:: IndexerTapContext ;
27+ use test_assets:: { create_signed_receipt, TAP_EIP712_DOMAIN } ;
28+
29+ const ALLOCATION_ID : Address = address ! ( "deadbeefcafebabedeadbeefcafebabedeadbeef" ) ;
30+
31+ async fn handle ( _: Request < Body > ) -> anyhow:: Result < Response < Body > > {
32+ Ok ( Response :: new ( Body :: default ( ) ) )
33+ }
34+
35+ #[ sqlx:: test( migrations = "../../migrations" ) ]
36+ async fn test_middleware_composition ( pgpool : PgPool ) {
37+ let token = "test" . to_string ( ) ;
38+ let context = IndexerTapContext :: new ( pgpool. clone ( ) , TAP_EIP712_DOMAIN . clone ( ) ) . await ;
39+ let tap_manager = Box :: leak ( Box :: new ( Manager :: new (
40+ TAP_EIP712_DOMAIN . clone ( ) ,
41+ context,
42+ CheckList :: empty ( ) ,
43+ ) ) ) ;
44+ let metric = Box :: leak ( Box :: new (
45+ prometheus:: register_counter_vec!(
46+ "merge_checks_test" ,
47+ "Failed queries to handler" ,
48+ & [ "deployment" ]
49+ )
50+ . unwrap ( ) ,
51+ ) ) ;
52+ let free_query = Bearer :: new ( & token) ;
53+ let tap_auth = auth:: tap_receipt_authorize ( tap_manager, metric) ;
54+ let authorize_requests = free_query. or ( tap_auth) ;
55+
56+ let authorization_middleware = AsyncRequireAuthorizationLayer :: new ( authorize_requests) ;
57+
58+ let mut service = ServiceBuilder :: new ( )
59+ . layer ( authorization_middleware)
60+ . service_fn ( handle) ;
61+
62+ let handle = service. ready ( ) . await . unwrap ( ) ;
63+
64+ // should allow queries that contains the free token
65+ // if the token does not match, return payment required
66+ let mut req = Request :: new ( Default :: default ( ) ) ;
67+ req. headers_mut ( ) . insert (
68+ header:: AUTHORIZATION ,
69+ format ! ( "Bearer {token}" ) . parse ( ) . unwrap ( ) ,
70+ ) ;
71+ let res = handle. call ( req) . await . unwrap ( ) ;
72+ assert_eq ! ( res. status( ) , StatusCode :: OK ) ;
73+
74+ // if the token exists but is wrong, try the receipt
75+ let mut req = Request :: new ( Default :: default ( ) ) ;
76+ req. headers_mut ( )
77+ . insert ( header:: AUTHORIZATION , "Bearer wrongtoken" . parse ( ) . unwrap ( ) ) ;
78+ let res = handle. call ( req) . await . unwrap ( ) ;
79+ // we return the error from tap
80+ assert_eq ! ( res. status( ) , StatusCode :: PAYMENT_REQUIRED ) ;
81+
82+ let receipt = create_signed_receipt ( ALLOCATION_ID , 1 , 1 , 1 ) . await ;
83+
84+ // check with receipt
85+ let mut req = Request :: new ( Default :: default ( ) ) ;
86+ req. extensions_mut ( ) . insert ( receipt) ;
87+ let res = handle. call ( req) . await . unwrap ( ) ;
88+ assert_eq ! ( res. status( ) , StatusCode :: OK ) ;
89+
90+ // todo make this sleep better
91+ tokio:: time:: sleep ( Duration :: from_millis ( 100 ) ) . await ;
92+
93+ // verify receipts
94+ let result = sqlx:: query!( "SELECT * FROM scalar_tap_receipts" )
95+ . fetch_all ( & pgpool)
96+ . await
97+ . unwrap ( ) ;
98+ assert_eq ! ( result. len( ) , 1 ) ;
99+
100+ // if it has neither, should return unauthorized
101+ // check no headers
102+ let req = Request :: new ( Default :: default ( ) ) ;
103+ let res = handle. call ( req) . await . unwrap ( ) ;
104+ assert_eq ! ( res. status( ) , StatusCode :: PAYMENT_REQUIRED ) ;
105+ }
106+ }
0 commit comments