@@ -11,15 +11,79 @@ use crate::{
1111 constants:: DapMediaType ,
1212 error:: DapAbort ,
1313 messages:: {
14- constant_time_eq,
15- request:: { AggregationJobRequestHash , HashedAggregationJobReq } ,
16- AggregateShare , AggregateShareReq , AggregationJobId , PartialBatchSelector , TaskId ,
14+ constant_time_eq, AggregateShare , AggregateShareReq , AggregationJobId ,
15+ AggregationJobInitReq , PartialBatchSelector , TaskId ,
1716 } ,
1817 metrics:: { DaphneRequestType , ReportStatus } ,
1918 protocol:: aggregator:: ReplayProtection ,
2019 DapAggregationParam , DapError , DapRequest , DapResponse , DapTaskConfig , DapVersion ,
2120} ;
2221
22+ #[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
23+ #[ cfg_attr( any( test, feature = "test-utils" ) , derive( deepsize:: DeepSizeOf ) ) ]
24+ pub struct AggregationJobRequestHash ( Vec < u8 > ) ;
25+
26+ impl AggregationJobRequestHash {
27+ pub fn get ( & self ) -> & [ u8 ] {
28+ & self . 0
29+ }
30+
31+ fn hash ( bytes : & [ u8 ] ) -> Self {
32+ Self (
33+ ring:: digest:: digest ( & ring:: digest:: SHA256 , bytes)
34+ . as_ref ( )
35+ . to_vec ( ) ,
36+ )
37+ }
38+ }
39+
40+ /// An [`AggregationJobInitReq`] and its hash. Used by the helper to prevent the parameters of an
41+ /// aggregation job from changing.
42+ pub struct HashedAggregationJobReq {
43+ pub request : AggregationJobInitReq ,
44+ pub hash : AggregationJobRequestHash ,
45+ }
46+
47+ impl HashedAggregationJobReq {
48+ #[ cfg( any( test, feature = "test-utils" ) ) ]
49+ pub fn from_aggregation_req ( version : DapVersion , request : AggregationJobInitReq ) -> Self {
50+ let mut buf = Vec :: new ( ) ;
51+ request. encode_with_param ( & version, & mut buf) . unwrap ( ) ;
52+ Self {
53+ request,
54+ hash : AggregationJobRequestHash :: hash ( & buf) ,
55+ }
56+ }
57+ }
58+
59+ impl ParameterizedEncode < DapVersion > for HashedAggregationJobReq {
60+ fn encode_with_param (
61+ & self ,
62+ encoding_parameter : & DapVersion ,
63+ bytes : & mut Vec < u8 > ,
64+ ) -> Result < ( ) , prio:: codec:: CodecError > {
65+ self . request . encode_with_param ( encoding_parameter, bytes)
66+ }
67+ }
68+
69+ impl ParameterizedDecode < DapVersion > for HashedAggregationJobReq {
70+ fn decode_with_param (
71+ decoding_parameter : & DapVersion ,
72+ bytes : & mut std:: io:: Cursor < & [ u8 ] > ,
73+ ) -> Result < Self , prio:: codec:: CodecError > {
74+ let start = usize:: try_from ( bytes. position ( ) )
75+ . map_err ( |e| prio:: codec:: CodecError :: Other ( Box :: new ( e) ) ) ?;
76+ let request = AggregationJobInitReq :: decode_with_param ( decoding_parameter, bytes) ?;
77+ let end = usize:: try_from ( bytes. position ( ) )
78+ . map_err ( |e| prio:: codec:: CodecError :: Other ( Box :: new ( e) ) ) ?;
79+
80+ Ok ( Self {
81+ request,
82+ hash : AggregationJobRequestHash :: hash ( & bytes. get_ref ( ) [ start..end] ) ,
83+ } )
84+ }
85+ }
86+
2387/// DAP Helper functionality.
2488#[ async_trait]
2589pub trait DapHelper : DapAggregator {
0 commit comments