Skip to content
This repository was archived by the owner on Jul 9, 2025. It is now read-only.

Commit 5568556

Browse files
Bug 1892204: Change using code for libprio 0.16 r=tcampbell
Differential Revision: https://phabricator.services.mozilla.com/D207844
1 parent 3e800d5 commit 5568556

File tree

6 files changed

+112
-83
lines changed

6 files changed

+112
-83
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

toolkit/components/telemetry/dap/ffi-gtest/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use prio::codec::{Decode, Encode};
1717
pub extern "C" fn dap_test_encoding() {
1818
let r = Report::new_dummy();
1919
let mut encoded = Vec::<u8>::new();
20-
Report::encode(&r, &mut encoded);
20+
Report::encode(&r, &mut encoded).expect("Report encoding failed!");
2121
let decoded = Report::decode(&mut Cursor::new(&encoded)).expect("Report decoding failed!");
2222
if r != decoded {
2323
println!("Report:");

toolkit/components/telemetry/dap/ffi/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ license = "MPL-2.0"
1111
prio = {version = "0.16.2", default-features = false }
1212
thin-vec = { version = "0.2.1", features = ["gecko-ffi"] }
1313
rand = "0.8"
14+
log = "0.4"

toolkit/components/telemetry/dap/ffi/src/lib.rs

Lines changed: 68 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
use std::error::Error;
66
use std::io::Cursor;
77

8+
use log::{debug, warn};
9+
810
use prio::vdaf::prio3::Prio3Sum;
911
use prio::vdaf::prio3::Prio3SumVec;
1012
use thin_vec::ThinVec;
@@ -19,8 +21,6 @@ use types::Time;
1921

2022
use prio::codec::Encode;
2123
use prio::codec::{decode_u16_items, encode_u32_items};
22-
use prio::flp::types::{Sum, SumVec};
23-
use prio::vdaf::prio3::Prio3;
2424
use prio::vdaf::Client;
2525
use prio::vdaf::VdafError;
2626

@@ -41,25 +41,24 @@ extern "C" {
4141
) -> bool;
4242
}
4343

44-
pub fn new_prio_u8(num_aggregators: u8, bits: u32) -> Result<Prio3Sum, VdafError> {
44+
pub fn new_prio_sum(num_aggregators: u8, bits: usize) -> Result<Prio3Sum, VdafError> {
4545
if bits > 64 {
4646
return Err(VdafError::Uncategorized(format!(
4747
"bit length ({}) exceeds limit for aggregate type (64)",
4848
bits
4949
)));
5050
}
5151

52-
Prio3::new(num_aggregators, Sum::new(bits as usize)?)
52+
Prio3Sum::new_sum(num_aggregators, bits)
5353
}
5454

55-
pub fn new_prio_vecu8(num_aggregators: u8, len: usize) -> Result<Prio3SumVec, VdafError> {
55+
pub fn new_prio_sumvec(
56+
num_aggregators: u8,
57+
len: usize,
58+
bits: usize,
59+
) -> Result<Prio3SumVec, VdafError> {
5660
let chunk_length = prio::vdaf::prio3::optimal_chunk_length(8 * len);
57-
Prio3::new(num_aggregators, SumVec::new(8, len, chunk_length)?)
58-
}
59-
60-
pub fn new_prio_vecu16(num_aggregators: u8, len: usize) -> Result<Prio3SumVec, VdafError> {
61-
let chunk_length = prio::vdaf::prio3::optimal_chunk_length(16 * len);
62-
Prio3::new(num_aggregators, SumVec::new(16, len, chunk_length)?)
61+
Prio3SumVec::new_sum_vec(num_aggregators, bits, len, chunk_length)
6362
}
6463

6564
enum Role {
@@ -112,14 +111,17 @@ impl Shardable for u8 {
112111
&self,
113112
nonce: &[u8; 16],
114113
) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>> {
115-
let prio = new_prio_u8(2, 2)?;
114+
let prio = new_prio_sum(2, 8)?;
116115

117116
let (public_share, input_shares) = prio.shard(&(*self as u128), nonce)?;
118117

119118
debug_assert_eq!(input_shares.len(), 2);
120119

121-
let encoded_input_shares = input_shares.iter().map(|s| s.get_encoded()).collect();
122-
let encoded_public_share = public_share.get_encoded();
120+
let encoded_input_shares = input_shares
121+
.iter()
122+
.map(|s| s.get_encoded())
123+
.collect::<Result<Vec<_>, _>>()?;
124+
let encoded_public_share = public_share.get_encoded()?;
123125
Ok((encoded_public_share, encoded_input_shares))
124126
}
125127
}
@@ -129,15 +131,18 @@ impl Shardable for ThinVec<u8> {
129131
&self,
130132
nonce: &[u8; 16],
131133
) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>> {
132-
let prio = new_prio_vecu8(2, self.len())?;
134+
let prio = new_prio_sumvec(2, self.len(), 8)?;
133135

134136
let measurement: Vec<u128> = self.iter().map(|e| (*e as u128)).collect();
135137
let (public_share, input_shares) = prio.shard(&measurement, nonce)?;
136138

137139
debug_assert_eq!(input_shares.len(), 2);
138140

139-
let encoded_input_shares = input_shares.iter().map(|s| s.get_encoded()).collect();
140-
let encoded_public_share = public_share.get_encoded();
141+
let encoded_input_shares = input_shares
142+
.iter()
143+
.map(|s| s.get_encoded())
144+
.collect::<Result<Vec<_>, _>>()?;
145+
let encoded_public_share = public_share.get_encoded()?;
141146
Ok((encoded_public_share, encoded_input_shares))
142147
}
143148
}
@@ -147,23 +152,26 @@ impl Shardable for ThinVec<u16> {
147152
&self,
148153
nonce: &[u8; 16],
149154
) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>> {
150-
let prio = new_prio_vecu16(2, self.len())?;
155+
let prio = new_prio_sumvec(2, self.len(), 16)?;
151156

152157
let measurement: Vec<u128> = self.iter().map(|e| (*e as u128)).collect();
153158
let (public_share, input_shares) = prio.shard(&measurement, nonce)?;
154159

155160
debug_assert_eq!(input_shares.len(), 2);
156161

157-
let encoded_input_shares = input_shares.iter().map(|s| s.get_encoded()).collect();
158-
let encoded_public_share = public_share.get_encoded();
162+
let encoded_input_shares = input_shares
163+
.iter()
164+
.map(|s| s.get_encoded())
165+
.collect::<Result<Vec<_>, _>>()?;
166+
let encoded_public_share = public_share.get_encoded()?;
159167
Ok((encoded_public_share, encoded_input_shares))
160168
}
161169
}
162170

163171
/// Pre-fill the info part of the HPKE sealing with the constants from the standard.
164172
fn make_base_info() -> Vec<u8> {
165173
let mut info = Vec::<u8>::new();
166-
const START: &[u8] = "dap-07 input share".as_bytes();
174+
const START: &[u8] = "dap-09 input share".as_bytes();
167175
info.extend(START);
168176
const FIXED: u8 = 1;
169177
info.push(FIXED);
@@ -215,7 +223,8 @@ fn get_dap_report_internal<T: Shardable>(
215223
}
216224
.get_encoded()
217225
})
218-
.collect();
226+
.collect::<Result<Vec<_>, _>>()?;
227+
debug!("Plaintext input shares computed.");
219228

220229
let metadata = ReportMetadata {
221230
report_id,
@@ -230,18 +239,20 @@ fn get_dap_report_internal<T: Shardable>(
230239
let mut info = make_base_info();
231240

232241
let mut aad = Vec::from(*task_id);
233-
metadata.encode(&mut aad);
234-
encode_u32_items(&mut aad, &(), &encoded_public_share);
242+
metadata.encode(&mut aad)?;
243+
encode_u32_items(&mut aad, &(), &encoded_public_share)?;
235244

236245
info.push(Role::Leader as u8);
237246

238247
let leader_payload =
239248
hpke_encrypt_wrapper(&plaintext_input_shares[0], &aad, &info, &leader_hpke_config)?;
249+
debug!("Leader payload encrypted.");
240250

241251
*info.last_mut().unwrap() = Role::Helper as u8;
242252

243253
let helper_payload =
244254
hpke_encrypt_wrapper(&plaintext_input_shares[1], &aad, &info, &helper_hpke_config)?;
255+
debug!("Helper payload encrypted.");
245256

246257
Ok(Report {
247258
metadata,
@@ -264,20 +275,22 @@ pub extern "C" fn dapGetReportU8(
264275
) -> bool {
265276
assert_eq!(task_id.len(), 32);
266277

267-
if let Ok(report) = get_dap_report_internal::<u8>(
278+
let Ok(report) = get_dap_report_internal::<u8>(
268279
leader_hpke_config_encoded,
269280
helper_hpke_config_encoded,
270281
&measurement,
271282
&task_id.as_slice().try_into().unwrap(),
272283
time_precision,
273-
) {
274-
let encoded_report = report.get_encoded();
275-
out_report.extend(encoded_report);
276-
277-
true
278-
} else {
279-
false
280-
}
284+
) else {
285+
warn!("Creating report failed!");
286+
return false;
287+
};
288+
let Ok(encoded_report) = report.get_encoded() else {
289+
warn!("Encoding report failed!");
290+
return false;
291+
};
292+
out_report.extend(encoded_report);
293+
true
281294
}
282295

283296
#[no_mangle]
@@ -291,20 +304,22 @@ pub extern "C" fn dapGetReportVecU8(
291304
) -> bool {
292305
assert_eq!(task_id.len(), 32);
293306

294-
if let Ok(report) = get_dap_report_internal::<ThinVec<u8>>(
307+
let Ok(report) = get_dap_report_internal::<ThinVec<u8>>(
295308
leader_hpke_config_encoded,
296309
helper_hpke_config_encoded,
297310
measurement,
298311
&task_id.as_slice().try_into().unwrap(),
299312
time_precision,
300-
) {
301-
let encoded_report = report.get_encoded();
302-
out_report.extend(encoded_report);
303-
304-
true
305-
} else {
306-
false
307-
}
313+
) else {
314+
warn!("Creating report failed!");
315+
return false;
316+
};
317+
let Ok(encoded_report) = report.get_encoded() else {
318+
warn!("Encoding report failed!");
319+
return false;
320+
};
321+
out_report.extend(encoded_report);
322+
true
308323
}
309324

310325
#[no_mangle]
@@ -318,18 +333,20 @@ pub extern "C" fn dapGetReportVecU16(
318333
) -> bool {
319334
assert_eq!(task_id.len(), 32);
320335

321-
if let Ok(report) = get_dap_report_internal::<ThinVec<u16>>(
336+
let Ok(report) = get_dap_report_internal::<ThinVec<u16>>(
322337
leader_hpke_config_encoded,
323338
helper_hpke_config_encoded,
324339
measurement,
325340
&task_id.as_slice().try_into().unwrap(),
326341
time_precision,
327-
) {
328-
let encoded_report = report.get_encoded();
329-
out_report.extend(encoded_report);
330-
331-
true
332-
} else {
333-
false
334-
}
342+
) else {
343+
warn!("Creating report failed!");
344+
return false;
345+
};
346+
let Ok(encoded_report) = report.get_encoded() else {
347+
warn!("Encoding report failed!");
348+
return false;
349+
};
350+
out_report.extend(encoded_report);
351+
true
335352
}

0 commit comments

Comments
 (0)