@@ -8,6 +8,7 @@ use super::{
88 CollectionJobId , CollectionReq , Report ,
99} ;
1010use crate :: { constants:: DapMediaType , error:: DapAbort , messages:: TaskId , DapVersion } ;
11+ use prio:: codec:: { ParameterizedDecode , ParameterizedEncode } ;
1112
1213pub trait RequestBody {
1314 type ResourceId ;
@@ -24,15 +25,79 @@ macro_rules! impl_req_body {
2425 } ;
2526}
2627
28+ #[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
29+ #[ cfg_attr( any( test, feature = "test-utils" ) , derive( deepsize:: DeepSizeOf ) ) ]
30+ pub struct AggregationJobRequestHash ( Vec < u8 > ) ;
31+
32+ impl AggregationJobRequestHash {
33+ pub fn get ( & self ) -> & [ u8 ] {
34+ & self . 0
35+ }
36+
37+ fn hash ( bytes : & [ u8 ] ) -> Self {
38+ Self (
39+ ring:: digest:: digest ( & ring:: digest:: SHA256 , bytes)
40+ . as_ref ( )
41+ . to_vec ( ) ,
42+ )
43+ }
44+ }
45+
46+ pub struct HashedAggregationJobReq {
47+ pub request : AggregationJobInitReq ,
48+ pub hash : AggregationJobRequestHash ,
49+ }
50+
51+ impl HashedAggregationJobReq {
52+ #[ cfg( any( test, feature = "test-utils" ) ) ]
53+ pub fn from_aggregation_req ( version : DapVersion , request : AggregationJobInitReq ) -> Self {
54+ let mut buf = Vec :: new ( ) ;
55+ request. encode_with_param ( & version, & mut buf) . unwrap ( ) ;
56+ Self {
57+ request,
58+ hash : AggregationJobRequestHash :: hash ( & buf) ,
59+ }
60+ }
61+ }
62+
63+ impl ParameterizedEncode < DapVersion > for HashedAggregationJobReq {
64+ fn encode_with_param (
65+ & self ,
66+ encoding_parameter : & DapVersion ,
67+ bytes : & mut Vec < u8 > ,
68+ ) -> Result < ( ) , prio:: codec:: CodecError > {
69+ self . request . encode_with_param ( encoding_parameter, bytes)
70+ }
71+ }
72+
73+ impl ParameterizedDecode < DapVersion > for HashedAggregationJobReq {
74+ fn decode_with_param (
75+ decoding_parameter : & DapVersion ,
76+ bytes : & mut std:: io:: Cursor < & [ u8 ] > ,
77+ ) -> Result < Self , prio:: codec:: CodecError > {
78+ let start = usize:: try_from ( bytes. position ( ) )
79+ . map_err ( |e| prio:: codec:: CodecError :: Other ( Box :: new ( e) ) ) ?;
80+ let request = AggregationJobInitReq :: decode_with_param ( decoding_parameter, bytes) ?;
81+ let end = usize:: try_from ( bytes. position ( ) )
82+ . map_err ( |e| prio:: codec:: CodecError :: Other ( Box :: new ( e) ) ) ?;
83+
84+ Ok ( Self {
85+ request,
86+ hash : AggregationJobRequestHash :: hash ( & bytes. get_ref ( ) [ start..end] ) ,
87+ } )
88+ }
89+ }
90+
2791impl_req_body ! {
28- // body type | id type
29- // --------------------- | ----------------
30- Report | ( )
31- AggregationJobInitReq | AggregationJobId
32- AggregateShareReq | ( )
33- CollectionReq | CollectionJobId
34- CollectionPollReq | CollectionJobId
35- ( ) | ( )
92+ // body type | id type
93+ // --------------------| ----------------
94+ Report | ( )
95+ AggregationJobInitReq | AggregationJobId
96+ HashedAggregationJobReq | AggregationJobId
97+ AggregateShareReq | ( )
98+ CollectionReq | CollectionJobId
99+ CollectionPollReq | CollectionJobId
100+ ( ) | ( )
36101}
37102
38103/// Fields common to all DAP requests.
@@ -74,6 +139,20 @@ pub struct DapRequest<B: RequestBody> {
74139 pub payload : B ,
75140}
76141
142+ impl < B : RequestBody > DapRequest < B > {
143+ pub fn map < F , O > ( self , mapper : F ) -> DapRequest < O >
144+ where
145+ F : FnOnce ( B ) -> O ,
146+ O : RequestBody < ResourceId = B :: ResourceId > ,
147+ {
148+ DapRequest {
149+ meta : self . meta ,
150+ resource_id : self . resource_id ,
151+ payload : mapper ( self . payload ) ,
152+ }
153+ }
154+ }
155+
77156impl < B : RequestBody > AsRef < DapRequestMeta > for DapRequest < B > {
78157 fn as_ref ( & self ) -> & DapRequestMeta {
79158 & self . meta
0 commit comments