22// SPDX-License-Identifier: Apache-2.0
33
44use axum_extra:: headers:: { self , Header , HeaderName , HeaderValue } ;
5+ use base64:: prelude:: * ;
56use lazy_static:: lazy_static;
67use prometheus:: { register_counter, Counter } ;
8+ use prost:: Message ;
9+ use tap_aggregator:: grpc;
710use tap_graph:: SignedReceipt ;
811
912use crate :: tap:: TapReceipt ;
@@ -26,17 +29,28 @@ impl Header for TapHeader {
2629 where
2730 I : Iterator < Item = & ' i HeaderValue > ,
2831 {
29- let mut execute = || {
30- let value = values. next ( ) ;
31- let raw_receipt = value. ok_or ( headers:: Error :: invalid ( ) ) ?;
32- let raw_receipt = raw_receipt
33- . to_str ( )
34- . map_err ( |_| headers:: Error :: invalid ( ) ) ?;
35- let parsed_receipt: SignedReceipt =
36- serde_json:: from_str ( raw_receipt) . map_err ( |_| headers:: Error :: invalid ( ) ) ?;
37- Ok ( TapHeader ( crate :: tap:: TapReceipt :: V1 ( parsed_receipt) ) )
32+ let mut execute = || -> anyhow:: Result < TapHeader > {
33+ let raw_receipt = values. next ( ) . ok_or ( headers:: Error :: invalid ( ) ) ?;
34+
35+ // we first try to decode a v2 receipt since it's cheaper and fail earlier than using
36+ // serde
37+ match BASE64_STANDARD . decode ( raw_receipt) {
38+ Ok ( raw_receipt) => {
39+ tracing:: debug!( "Decoded v2" ) ;
40+ let receipt = grpc:: v2:: SignedReceipt :: decode ( raw_receipt. as_ref ( ) ) ?;
41+ Ok ( TapHeader ( TapReceipt :: V2 ( receipt. try_into ( ) ?) ) )
42+ }
43+ Err ( _) => {
44+ tracing:: debug!( "Could not decode v2, trying v1" ) ;
45+ let parsed_receipt: SignedReceipt =
46+ serde_json:: from_slice ( raw_receipt. as_ref ( ) ) ?;
47+ Ok ( TapHeader ( TapReceipt :: V1 ( parsed_receipt) ) )
48+ }
49+ }
3850 } ;
39- execute ( ) . inspect_err ( |_| TAP_RECEIPT_INVALID . inc ( ) )
51+ execute ( )
52+ . map_err ( |_| headers:: Error :: invalid ( ) )
53+ . inspect_err ( |_| TAP_RECEIPT_INVALID . inc ( ) )
4054 }
4155
4256 fn encode < E > ( & self , _values : & mut E )
@@ -51,13 +65,16 @@ impl Header for TapHeader {
5165mod test {
5266 use axum:: http:: HeaderValue ;
5367 use axum_extra:: headers:: Header ;
54- use test_assets:: { create_signed_receipt, SignedReceiptRequest } ;
68+ use base64:: prelude:: * ;
69+ use prost:: Message ;
70+ use tap_aggregator:: grpc:: v2:: SignedReceipt ;
71+ use test_assets:: { create_signed_receipt, create_signed_receipt_v2, SignedReceiptRequest } ;
5572
5673 use super :: TapHeader ;
5774 use crate :: tap:: TapReceipt ;
5875
5976 #[ tokio:: test]
60- async fn test_decode_valid_tap_receipt_header ( ) {
77+ async fn test_decode_valid_tap_v1_receipt_header ( ) {
6178 let original_receipt = create_signed_receipt ( SignedReceiptRequest :: builder ( ) . build ( ) ) . await ;
6279 let serialized_receipt = serde_json:: to_string ( & original_receipt) . unwrap ( ) ;
6380 let header_value = HeaderValue :: from_str ( & serialized_receipt) . unwrap ( ) ;
@@ -68,6 +85,20 @@ mod test {
6885 assert_eq ! ( decoded_receipt, TapHeader ( TapReceipt :: V1 ( original_receipt) ) ) ;
6986 }
7087
88+ #[ test_log:: test( tokio:: test) ]
89+ async fn test_decode_valid_tap_v2_receipt_header ( ) {
90+ let original_receipt = create_signed_receipt_v2 ( ) . call ( ) . await ;
91+ let protobuf_receipt = SignedReceipt :: from ( original_receipt. clone ( ) ) ;
92+ let encoded = protobuf_receipt. encode_to_vec ( ) ;
93+ let base64_encoded = BASE64_STANDARD . encode ( encoded) ;
94+ let header_value = HeaderValue :: from_str ( & base64_encoded) . unwrap ( ) ;
95+ let header_values = vec ! [ & header_value] ;
96+ let decoded_receipt = TapHeader :: decode ( & mut header_values. into_iter ( ) )
97+ . expect ( "tap receipt header value should be valid" ) ;
98+
99+ assert_eq ! ( decoded_receipt, TapHeader ( TapReceipt :: V2 ( original_receipt) ) ) ;
100+ }
101+
71102 #[ test]
72103 fn test_decode_non_string_tap_receipt_header ( ) {
73104 let header_value = HeaderValue :: from_static ( "123" ) ;
0 commit comments