5
5
use std:: error:: Error ;
6
6
use std:: io:: Cursor ;
7
7
8
+ use log:: { debug, warn} ;
9
+
8
10
use prio:: vdaf:: prio3:: Prio3Sum ;
9
11
use prio:: vdaf:: prio3:: Prio3SumVec ;
10
12
use thin_vec:: ThinVec ;
@@ -19,8 +21,6 @@ use types::Time;
19
21
20
22
use prio:: codec:: Encode ;
21
23
use prio:: codec:: { decode_u16_items, encode_u32_items} ;
22
- use prio:: flp:: types:: { Sum , SumVec } ;
23
- use prio:: vdaf:: prio3:: Prio3 ;
24
24
use prio:: vdaf:: Client ;
25
25
use prio:: vdaf:: VdafError ;
26
26
@@ -41,25 +41,24 @@ extern "C" {
41
41
) -> bool ;
42
42
}
43
43
44
- pub fn new_prio_u8 ( num_aggregators : u8 , bits : u32 ) -> Result < Prio3Sum , VdafError > {
44
+ pub fn new_prio_sum ( num_aggregators : u8 , bits : usize ) -> Result < Prio3Sum , VdafError > {
45
45
if bits > 64 {
46
46
return Err ( VdafError :: Uncategorized ( format ! (
47
47
"bit length ({}) exceeds limit for aggregate type (64)" ,
48
48
bits
49
49
) ) ) ;
50
50
}
51
51
52
- Prio3 :: new ( num_aggregators, Sum :: new ( bits as usize ) ? )
52
+ Prio3Sum :: new_sum ( num_aggregators, bits)
53
53
}
54
54
55
- pub fn new_prio_vecu8 ( num_aggregators : u8 , len : usize ) -> Result < Prio3SumVec , VdafError > {
55
+ pub fn new_prio_sumvec (
56
+ num_aggregators : u8 ,
57
+ len : usize ,
58
+ bits : usize ,
59
+ ) -> Result < Prio3SumVec , VdafError > {
56
60
let chunk_length = prio:: vdaf:: prio3:: optimal_chunk_length ( 8 * len) ;
57
- Prio3 :: new ( num_aggregators, SumVec :: new ( 8 , len, chunk_length) ?)
58
- }
59
-
60
- pub fn new_prio_vecu16 ( num_aggregators : u8 , len : usize ) -> Result < Prio3SumVec , VdafError > {
61
- let chunk_length = prio:: vdaf:: prio3:: optimal_chunk_length ( 16 * len) ;
62
- Prio3 :: new ( num_aggregators, SumVec :: new ( 16 , len, chunk_length) ?)
61
+ Prio3SumVec :: new_sum_vec ( num_aggregators, bits, len, chunk_length)
63
62
}
64
63
65
64
enum Role {
@@ -112,14 +111,17 @@ impl Shardable for u8 {
112
111
& self ,
113
112
nonce : & [ u8 ; 16 ] ,
114
113
) -> Result < ( Vec < u8 > , Vec < Vec < u8 > > ) , Box < dyn std:: error:: Error > > {
115
- let prio = new_prio_u8 ( 2 , 2 ) ?;
114
+ let prio = new_prio_sum ( 2 , 8 ) ?;
116
115
117
116
let ( public_share, input_shares) = prio. shard ( & ( * self as u128 ) , nonce) ?;
118
117
119
118
debug_assert_eq ! ( input_shares. len( ) , 2 ) ;
120
119
121
- let encoded_input_shares = input_shares. iter ( ) . map ( |s| s. get_encoded ( ) ) . collect ( ) ;
122
- let encoded_public_share = public_share. get_encoded ( ) ;
120
+ let encoded_input_shares = input_shares
121
+ . iter ( )
122
+ . map ( |s| s. get_encoded ( ) )
123
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
124
+ let encoded_public_share = public_share. get_encoded ( ) ?;
123
125
Ok ( ( encoded_public_share, encoded_input_shares) )
124
126
}
125
127
}
@@ -129,15 +131,18 @@ impl Shardable for ThinVec<u8> {
129
131
& self ,
130
132
nonce : & [ u8 ; 16 ] ,
131
133
) -> Result < ( Vec < u8 > , Vec < Vec < u8 > > ) , Box < dyn std:: error:: Error > > {
132
- let prio = new_prio_vecu8 ( 2 , self . len ( ) ) ?;
134
+ let prio = new_prio_sumvec ( 2 , self . len ( ) , 8 ) ?;
133
135
134
136
let measurement: Vec < u128 > = self . iter ( ) . map ( |e| ( * e as u128 ) ) . collect ( ) ;
135
137
let ( public_share, input_shares) = prio. shard ( & measurement, nonce) ?;
136
138
137
139
debug_assert_eq ! ( input_shares. len( ) , 2 ) ;
138
140
139
- let encoded_input_shares = input_shares. iter ( ) . map ( |s| s. get_encoded ( ) ) . collect ( ) ;
140
- let encoded_public_share = public_share. get_encoded ( ) ;
141
+ let encoded_input_shares = input_shares
142
+ . iter ( )
143
+ . map ( |s| s. get_encoded ( ) )
144
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
145
+ let encoded_public_share = public_share. get_encoded ( ) ?;
141
146
Ok ( ( encoded_public_share, encoded_input_shares) )
142
147
}
143
148
}
@@ -147,23 +152,26 @@ impl Shardable for ThinVec<u16> {
147
152
& self ,
148
153
nonce : & [ u8 ; 16 ] ,
149
154
) -> Result < ( Vec < u8 > , Vec < Vec < u8 > > ) , Box < dyn std:: error:: Error > > {
150
- let prio = new_prio_vecu16 ( 2 , self . len ( ) ) ?;
155
+ let prio = new_prio_sumvec ( 2 , self . len ( ) , 16 ) ?;
151
156
152
157
let measurement: Vec < u128 > = self . iter ( ) . map ( |e| ( * e as u128 ) ) . collect ( ) ;
153
158
let ( public_share, input_shares) = prio. shard ( & measurement, nonce) ?;
154
159
155
160
debug_assert_eq ! ( input_shares. len( ) , 2 ) ;
156
161
157
- let encoded_input_shares = input_shares. iter ( ) . map ( |s| s. get_encoded ( ) ) . collect ( ) ;
158
- let encoded_public_share = public_share. get_encoded ( ) ;
162
+ let encoded_input_shares = input_shares
163
+ . iter ( )
164
+ . map ( |s| s. get_encoded ( ) )
165
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
166
+ let encoded_public_share = public_share. get_encoded ( ) ?;
159
167
Ok ( ( encoded_public_share, encoded_input_shares) )
160
168
}
161
169
}
162
170
163
171
/// Pre-fill the info part of the HPKE sealing with the constants from the standard.
164
172
fn make_base_info ( ) -> Vec < u8 > {
165
173
let mut info = Vec :: < u8 > :: new ( ) ;
166
- const START : & [ u8 ] = "dap-07 input share" . as_bytes ( ) ;
174
+ const START : & [ u8 ] = "dap-09 input share" . as_bytes ( ) ;
167
175
info. extend ( START ) ;
168
176
const FIXED : u8 = 1 ;
169
177
info. push ( FIXED ) ;
@@ -215,7 +223,8 @@ fn get_dap_report_internal<T: Shardable>(
215
223
}
216
224
. get_encoded ( )
217
225
} )
218
- . collect ( ) ;
226
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
227
+ debug ! ( "Plaintext input shares computed." ) ;
219
228
220
229
let metadata = ReportMetadata {
221
230
report_id,
@@ -230,18 +239,20 @@ fn get_dap_report_internal<T: Shardable>(
230
239
let mut info = make_base_info ( ) ;
231
240
232
241
let mut aad = Vec :: from ( * task_id) ;
233
- metadata. encode ( & mut aad) ;
234
- encode_u32_items ( & mut aad, & ( ) , & encoded_public_share) ;
242
+ metadata. encode ( & mut aad) ? ;
243
+ encode_u32_items ( & mut aad, & ( ) , & encoded_public_share) ? ;
235
244
236
245
info. push ( Role :: Leader as u8 ) ;
237
246
238
247
let leader_payload =
239
248
hpke_encrypt_wrapper ( & plaintext_input_shares[ 0 ] , & aad, & info, & leader_hpke_config) ?;
249
+ debug ! ( "Leader payload encrypted." ) ;
240
250
241
251
* info. last_mut ( ) . unwrap ( ) = Role :: Helper as u8 ;
242
252
243
253
let helper_payload =
244
254
hpke_encrypt_wrapper ( & plaintext_input_shares[ 1 ] , & aad, & info, & helper_hpke_config) ?;
255
+ debug ! ( "Helper payload encrypted." ) ;
245
256
246
257
Ok ( Report {
247
258
metadata,
@@ -264,20 +275,22 @@ pub extern "C" fn dapGetReportU8(
264
275
) -> bool {
265
276
assert_eq ! ( task_id. len( ) , 32 ) ;
266
277
267
- if let Ok ( report) = get_dap_report_internal :: < u8 > (
278
+ let Ok ( report) = get_dap_report_internal :: < u8 > (
268
279
leader_hpke_config_encoded,
269
280
helper_hpke_config_encoded,
270
281
& measurement,
271
282
& task_id. as_slice ( ) . try_into ( ) . unwrap ( ) ,
272
283
time_precision,
273
- ) {
274
- let encoded_report = report. get_encoded ( ) ;
275
- out_report. extend ( encoded_report) ;
276
-
277
- true
278
- } else {
279
- false
280
- }
284
+ ) else {
285
+ warn ! ( "Creating report failed!" ) ;
286
+ return false ;
287
+ } ;
288
+ let Ok ( encoded_report) = report. get_encoded ( ) else {
289
+ warn ! ( "Encoding report failed!" ) ;
290
+ return false ;
291
+ } ;
292
+ out_report. extend ( encoded_report) ;
293
+ true
281
294
}
282
295
283
296
#[ no_mangle]
@@ -291,20 +304,22 @@ pub extern "C" fn dapGetReportVecU8(
291
304
) -> bool {
292
305
assert_eq ! ( task_id. len( ) , 32 ) ;
293
306
294
- if let Ok ( report) = get_dap_report_internal :: < ThinVec < u8 > > (
307
+ let Ok ( report) = get_dap_report_internal :: < ThinVec < u8 > > (
295
308
leader_hpke_config_encoded,
296
309
helper_hpke_config_encoded,
297
310
measurement,
298
311
& task_id. as_slice ( ) . try_into ( ) . unwrap ( ) ,
299
312
time_precision,
300
- ) {
301
- let encoded_report = report. get_encoded ( ) ;
302
- out_report. extend ( encoded_report) ;
303
-
304
- true
305
- } else {
306
- false
307
- }
313
+ ) else {
314
+ warn ! ( "Creating report failed!" ) ;
315
+ return false ;
316
+ } ;
317
+ let Ok ( encoded_report) = report. get_encoded ( ) else {
318
+ warn ! ( "Encoding report failed!" ) ;
319
+ return false ;
320
+ } ;
321
+ out_report. extend ( encoded_report) ;
322
+ true
308
323
}
309
324
310
325
#[ no_mangle]
@@ -318,18 +333,20 @@ pub extern "C" fn dapGetReportVecU16(
318
333
) -> bool {
319
334
assert_eq ! ( task_id. len( ) , 32 ) ;
320
335
321
- if let Ok ( report) = get_dap_report_internal :: < ThinVec < u16 > > (
336
+ let Ok ( report) = get_dap_report_internal :: < ThinVec < u16 > > (
322
337
leader_hpke_config_encoded,
323
338
helper_hpke_config_encoded,
324
339
measurement,
325
340
& task_id. as_slice ( ) . try_into ( ) . unwrap ( ) ,
326
341
time_precision,
327
- ) {
328
- let encoded_report = report. get_encoded ( ) ;
329
- out_report. extend ( encoded_report) ;
330
-
331
- true
332
- } else {
333
- false
334
- }
342
+ ) else {
343
+ warn ! ( "Creating report failed!" ) ;
344
+ return false ;
345
+ } ;
346
+ let Ok ( encoded_report) = report. get_encoded ( ) else {
347
+ warn ! ( "Encoding report failed!" ) ;
348
+ return false ;
349
+ } ;
350
+ out_report. extend ( encoded_report) ;
351
+ true
335
352
}
0 commit comments