44use axum_extra:: headers:: { self , Header , HeaderName , HeaderValue } ;
55use lazy_static:: lazy_static;
66use prometheus:: { register_counter, Counter } ;
7- use tap_core:: receipt:: SignedReceipt ;
7+ use serde:: de;
8+ use serde_json:: Value ;
9+ use tap_core:: receipt:: SignedReceipt as SignedReceiptV1 ;
10+ use tap_core_v2:: receipt:: SignedReceipt as SignedReceiptV2 ;
811
9- #[ derive( Debug , PartialEq ) ]
10- pub struct TapReceipt ( pub SignedReceipt ) ;
12+ #[ derive( Debug , PartialEq , Clone , serde:: Serialize ) ]
13+ #[ serde( untagged) ]
14+ pub enum TapReceipt {
15+ V1 ( SignedReceiptV1 ) ,
16+ V2 ( SignedReceiptV2 ) ,
17+ }
18+
19+ impl < ' de > serde:: Deserialize < ' de > for TapReceipt {
20+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
21+ where
22+ D : serde:: Deserializer < ' de > ,
23+ {
24+ let temp = Value :: deserialize ( deserializer) ?;
25+
26+ let is_v1 = temp
27+ . as_object ( )
28+ . ok_or ( de:: Error :: custom ( "Didn't receive an object" ) ) ?
29+ . get ( "message" )
30+ . ok_or ( de:: Error :: custom ( "There's no message in the object" ) ) ?
31+ . as_object ( )
32+ . ok_or ( de:: Error :: custom ( "Message is not an object" ) ) ?
33+ . contains_key ( "allocation_id" ) ;
34+
35+ if is_v1 {
36+ // Try V1 first
37+ serde_json:: from_value :: < SignedReceiptV1 > ( temp)
38+ . map ( TapReceipt :: V1 )
39+ . map_err ( de:: Error :: custom)
40+ } else {
41+ // Fall back to V2
42+ serde_json:: from_value :: < SignedReceiptV2 > ( temp)
43+ . map ( TapReceipt :: V2 )
44+ . map_err ( de:: Error :: custom)
45+ }
46+ }
47+ }
48+
49+ impl From < SignedReceiptV1 > for TapReceipt {
50+ fn from ( value : SignedReceiptV1 ) -> Self {
51+ Self :: V1 ( value)
52+ }
53+ }
54+
55+ impl From < SignedReceiptV2 > for TapReceipt {
56+ fn from ( value : SignedReceiptV2 ) -> Self {
57+ Self :: V2 ( value)
58+ }
59+ }
60+
61+ impl TryFrom < TapReceipt > for SignedReceiptV1 {
62+ type Error = anyhow:: Error ;
63+
64+ fn try_from ( value : TapReceipt ) -> Result < Self , Self :: Error > {
65+ match value {
66+ TapReceipt :: V2 ( _) => Err ( anyhow:: anyhow!( "TapReceipt is V2" ) ) ,
67+ TapReceipt :: V1 ( receipt) => Ok ( receipt) ,
68+ }
69+ }
70+ }
71+
72+ impl TryFrom < TapReceipt > for SignedReceiptV2 {
73+ type Error = anyhow:: Error ;
74+
75+ fn try_from ( value : TapReceipt ) -> Result < Self , Self :: Error > {
76+ match value {
77+ TapReceipt :: V1 ( _) => Err ( anyhow:: anyhow!( "TapReceipt is V1" ) ) ,
78+ TapReceipt :: V2 ( receipt) => Ok ( receipt) ,
79+ }
80+ }
81+ }
1182
1283lazy_static ! {
1384 static ref TAP_RECEIPT : HeaderName = HeaderName :: from_static( "tap-receipt" ) ;
@@ -30,9 +101,9 @@ impl Header for TapReceipt {
30101 let raw_receipt = raw_receipt
31102 . to_str ( )
32103 . map_err ( |_| headers:: Error :: invalid ( ) ) ?;
33- let parsed_receipt =
104+ let parsed_receipt: TapReceipt =
34105 serde_json:: from_str ( raw_receipt) . map_err ( |_| headers:: Error :: invalid ( ) ) ?;
35- Ok ( TapReceipt ( parsed_receipt) )
106+ Ok ( parsed_receipt)
36107 } ;
37108 execute ( ) . inspect_err ( |_| TAP_RECEIPT_INVALID . inc ( ) )
38109 }
@@ -49,20 +120,54 @@ impl Header for TapReceipt {
49120mod test {
50121 use axum:: http:: HeaderValue ;
51122 use axum_extra:: headers:: Header ;
52- use test_assets:: { create_signed_receipt, SignedReceiptRequest } ;
123+ use test_assets:: {
124+ create_signed_receipt, create_signed_receipt_v2, SignedReceiptRequest ,
125+ SignedReceiptV2Request ,
126+ } ;
53127
54128 use super :: TapReceipt ;
55129
56130 #[ tokio:: test]
57131 async fn test_decode_valid_tap_receipt_header ( ) {
58132 let original_receipt = create_signed_receipt ( SignedReceiptRequest :: builder ( ) . build ( ) ) . await ;
59133 let serialized_receipt = serde_json:: to_string ( & original_receipt) . unwrap ( ) ;
60- let header_value = HeaderValue :: from_str ( & serialized_receipt) . unwrap ( ) ;
134+
135+ let original_receipt_v1: TapReceipt = original_receipt. clone ( ) . into ( ) ;
136+ let serialized_receipt_v1 = serde_json:: to_string ( & original_receipt_v1) . unwrap ( ) ;
137+
138+ assert_eq ! ( serialized_receipt, serialized_receipt_v1) ;
139+
140+ println ! ( "Was able to serialize properly: {serialized_receipt_v1:?}" ) ;
141+ let deserialized: TapReceipt = serde_json:: from_str ( & serialized_receipt_v1) . unwrap ( ) ;
142+ println ! ( "Was able to deserialize properly: {deserialized:?}" ) ;
143+ let header_value = HeaderValue :: from_str ( & serialized_receipt_v1) . unwrap ( ) ;
144+ let header_values = vec ! [ & header_value] ;
145+ let decoded_receipt = TapReceipt :: decode ( & mut header_values. into_iter ( ) )
146+ . expect ( "tap receipt header value should be valid" ) ;
147+
148+ assert_eq ! ( decoded_receipt, original_receipt. into( ) ) ;
149+ }
150+
151+ #[ tokio:: test]
152+ async fn test_decode_valid_tap_v2_receipt_header ( ) {
153+ let original_receipt =
154+ create_signed_receipt_v2 ( SignedReceiptV2Request :: builder ( ) . build ( ) ) . await ;
155+ let serialized_receipt = serde_json:: to_string ( & original_receipt) . unwrap ( ) ;
156+
157+ let original_receipt_v1: TapReceipt = original_receipt. clone ( ) . into ( ) ;
158+ let serialized_receipt_v1 = serde_json:: to_string ( & original_receipt_v1) . unwrap ( ) ;
159+
160+ assert_eq ! ( serialized_receipt, serialized_receipt_v1) ;
161+
162+ println ! ( "Was able to serialize properly: {serialized_receipt_v1:?}" ) ;
163+ let deserialized: TapReceipt = serde_json:: from_str ( & serialized_receipt_v1) . unwrap ( ) ;
164+ println ! ( "Was able to deserialize properly: {deserialized:?}" ) ;
165+ let header_value = HeaderValue :: from_str ( & serialized_receipt_v1) . unwrap ( ) ;
61166 let header_values = vec ! [ & header_value] ;
62167 let decoded_receipt = TapReceipt :: decode ( & mut header_values. into_iter ( ) )
63168 . expect ( "tap receipt header value should be valid" ) ;
64169
65- assert_eq ! ( decoded_receipt, TapReceipt ( original_receipt) ) ;
170+ assert_eq ! ( decoded_receipt, original_receipt. into ( ) ) ;
66171 }
67172
68173 #[ test]
0 commit comments