Skip to content

Commit 5ee436a

Browse files
committed
Add tests to capnp decoding/encoding
1 parent 888fe77 commit 5ee436a

File tree

4 files changed

+215
-23
lines changed

4 files changed

+215
-23
lines changed

crates/daphne-service-utils/src/capnproto/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,80 @@ pub fn usize_to_capnp_len(u: usize) -> u32 {
268268
u.try_into()
269269
.expect("capnp can't encode more that u32::MAX of something")
270270
}
271+
272+
#[cfg(test)]
273+
pub fn roundtrip_test<T>(before: T)
274+
where
275+
T: CapnprotoPayloadDecode + CapnprotoPayloadEncode + PartialEq + std::fmt::Debug,
276+
{
277+
assert_eq!(
278+
before,
279+
T::decode_from_bytes(&before.encode_to_bytes()).unwrap()
280+
);
281+
}
282+
283+
#[cfg(test)]
284+
mod tests {
285+
use super::*;
286+
use rand::Rng;
287+
288+
#[test]
289+
fn test_u8_array_serialize_deserialize() {
290+
roundtrip_test(rand::thread_rng().gen::<[u8; 32]>());
291+
roundtrip_test(rand::thread_rng().gen::<[u8; 16]>());
292+
}
293+
294+
#[test]
295+
fn test_partial_batch_selector_serialize_deserialize() {
296+
roundtrip_test(PartialBatchSelector::TimeInterval);
297+
298+
roundtrip_test(PartialBatchSelector::LeaderSelectedByBatchId {
299+
batch_id: BatchId(rand::thread_rng().gen()),
300+
});
301+
}
302+
303+
#[test]
304+
fn test_report_error_conversion() {
305+
// cause a compilation error if the variants change
306+
const _: () = {
307+
#[allow(clippy::match_same_arms)]
308+
match messages::ReportError::Reserved {
309+
messages::ReportError::Reserved => (),
310+
messages::ReportError::BatchCollected => (),
311+
messages::ReportError::ReportReplayed => (),
312+
messages::ReportError::ReportDropped => (),
313+
messages::ReportError::HpkeUnknownConfigId => (),
314+
messages::ReportError::HpkeDecryptError => (),
315+
messages::ReportError::VdafPrepError => (),
316+
messages::ReportError::BatchSaturated => (),
317+
messages::ReportError::TaskExpired => (),
318+
messages::ReportError::InvalidMessage => (),
319+
messages::ReportError::ReportTooEarly => (),
320+
messages::ReportError::TaskNotStarted => (),
321+
}
322+
};
323+
let all_errors = vec![
324+
messages::ReportError::Reserved,
325+
messages::ReportError::BatchCollected,
326+
messages::ReportError::ReportReplayed,
327+
messages::ReportError::ReportDropped,
328+
messages::ReportError::HpkeUnknownConfigId,
329+
messages::ReportError::HpkeDecryptError,
330+
messages::ReportError::VdafPrepError,
331+
messages::ReportError::BatchSaturated,
332+
messages::ReportError::TaskExpired,
333+
messages::ReportError::InvalidMessage,
334+
messages::ReportError::ReportTooEarly,
335+
messages::ReportError::TaskNotStarted,
336+
];
337+
338+
for error in all_errors {
339+
let converted: base_capnp::ReportError = error.into();
340+
let back_converted: messages::ReportError = converted.into();
341+
assert_eq!(
342+
error, back_converted,
343+
"Conversion symmetry failed for {error:?}",
344+
);
345+
}
346+
}
347+
}

crates/daphne-service-utils/src/compute_offload/mod.rs

Lines changed: 126 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use daphne::{
2323
use prio::codec::{Decode, Encode, ParameterizedDecode, ParameterizedEncode};
2424
use std::{borrow::Cow, ops::Range};
2525

26+
#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq, Debug))]
2627
pub struct InitializeReports<'s> {
2728
pub hpke_keys: Cow<'s, [HpkeReceiverConfig]>,
2829
/// Output of [`DapAggregator::valid_report_time_range`](daphne::roles::DapAggregator) at the
@@ -489,11 +490,108 @@ fn to_capnp<E: ToString>(e: E) -> capnp::Error {
489490
#[cfg(test)]
490491
mod test {
491492
use super::*;
492-
use crate::capnproto::{CapnprotoPayloadDecodeExt, CapnprotoPayloadEncodeExt};
493+
use crate::capnproto::roundtrip_test;
494+
use daphne::{vdaf::VdafVerifyKey, DapVersion};
495+
use rand::Rng;
493496

494497
#[test]
495498
fn report_metadata_roundtrip() {
496-
let report_metadata = ReportMetadata {
499+
roundtrip_test(generate_random_report_metadata());
500+
}
501+
502+
#[test]
503+
fn test_encode_decode_range() {
504+
let mut rng = rand::thread_rng();
505+
506+
let start = rng.gen::<messages::Time>();
507+
let end = rng.gen::<messages::Time>();
508+
509+
roundtrip_test(start..end);
510+
}
511+
512+
#[test]
513+
fn test_encode_decode_partial_dap_task_config_for_report_init() {
514+
roundtrip_test(generate_random_partial_task_config());
515+
}
516+
517+
#[test]
518+
fn test_encode_decode_prepare_init() {
519+
roundtrip_test(generate_random_prepare_init());
520+
}
521+
522+
#[test]
523+
fn test_encode_decode_hpke_config() {
524+
roundtrip_test(generate_random_hpke_config());
525+
}
526+
527+
#[test]
528+
fn test_encode_decode_hpke_receiver_config() {
529+
roundtrip_test(HpkeReceiverConfig {
530+
config: generate_random_hpke_config(),
531+
private_key: generate_random_bytes().into(),
532+
});
533+
}
534+
535+
#[test]
536+
fn test_encode_decode_initialize_reports() {
537+
let initialize_reports = InitializeReports {
538+
hpke_keys: generate_random_hpke_receiver_configs().into(),
539+
valid_report_range: generate_random_valid_report_range(),
540+
task_id: TaskId(rand::thread_rng().gen()),
541+
task_config: generate_random_partial_task_config(),
542+
agg_param: generate_random_bytes().into(),
543+
prep_inits: generate_random_prep_inits(),
544+
};
545+
546+
roundtrip_test(initialize_reports);
547+
}
548+
549+
fn generate_random_partial_task_config() -> PartialDapTaskConfigForReportInit<'static> {
550+
let mut rng = rand::thread_rng();
551+
PartialDapTaskConfigForReportInit {
552+
not_before: rng.gen::<u64>(),
553+
not_after: rng.gen::<u64>(),
554+
method_is_taskprov: rng.gen::<bool>(),
555+
version: [DapVersion::Latest, DapVersion::Draft09][rng.gen_range(0..2)],
556+
vdaf: Cow::Owned(generate_random_vdaf()),
557+
vdaf_verify_key: VdafVerifyKey::from(rng.gen::<[u8; 32]>()),
558+
}
559+
}
560+
561+
fn generate_random_hpke_receiver_configs() -> Vec<HpkeReceiverConfig> {
562+
(0..3)
563+
.map(|_| generate_random_hpke_receiver_config())
564+
.collect()
565+
}
566+
567+
fn generate_random_hpke_receiver_config() -> HpkeReceiverConfig {
568+
HpkeReceiverConfig {
569+
config: generate_random_hpke_config(),
570+
private_key: generate_random_bytes().into(),
571+
}
572+
}
573+
574+
fn generate_random_valid_report_range() -> std::ops::Range<u64> {
575+
let mut rng = rand::thread_rng();
576+
rng.gen_range(0..1000)..rng.gen_range(1000..2000)
577+
}
578+
579+
fn generate_random_prep_inits() -> Vec<PrepareInit> {
580+
(0..2).map(|_| generate_random_prepare_init()).collect()
581+
}
582+
583+
fn generate_random_hpke_config() -> HpkeConfig {
584+
HpkeConfig {
585+
id: rand::thread_rng().gen(),
586+
kem_id: daphne::hpke::HpkeKemId::P256HkdfSha256,
587+
kdf_id: daphne::hpke::HpkeKdfId::HkdfSha256,
588+
aead_id: daphne::hpke::HpkeAeadId::Aes128Gcm,
589+
public_key: generate_random_bytes().into(),
590+
}
591+
}
592+
593+
fn generate_random_report_metadata() -> ReportMetadata {
594+
ReportMetadata {
497595
id: messages::ReportId(rand::random()),
498596
time: rand::random(),
499597
public_extensions: Some(vec![
@@ -503,27 +601,35 @@ mod test {
503601
payload: b"some extension payload".to_vec(),
504602
},
505603
]),
506-
};
604+
}
605+
}
507606

508-
assert_eq!(
509-
report_metadata,
510-
ReportMetadata::decode_from_bytes(&report_metadata.encode_to_bytes()).unwrap()
511-
);
607+
fn generate_random_bytes() -> Vec<u8> {
608+
let mut rng = rand::thread_rng();
609+
(0..rng.gen_range(0..32)).map(|_| rng.gen()).collect()
512610
}
513611

514-
#[test]
515-
fn report_metadata_roundtrip_draft09() {
516-
let report_metadata = ReportMetadata {
517-
id: messages::ReportId(rand::random()),
518-
time: rand::random(),
519-
// draft09 compatibility: Previously there was no extensions field in the report
520-
// metadata.
521-
public_extensions: None,
522-
};
612+
fn generate_random_vdaf() -> VdafConfig {
613+
// Replace with the actual logic to generate random VdafConfig
614+
VdafConfig::Prio2 { dimension: 1234 }
615+
}
523616

524-
assert_eq!(
525-
report_metadata,
526-
ReportMetadata::decode_from_bytes(&report_metadata.encode_to_bytes()).unwrap()
527-
);
617+
fn generate_random_prepare_init() -> PrepareInit {
618+
PrepareInit {
619+
report_share: generate_random_report_share(),
620+
payload: generate_random_bytes(),
621+
}
622+
}
623+
624+
fn generate_random_report_share() -> ReportShare {
625+
ReportShare {
626+
report_metadata: generate_random_report_metadata(),
627+
public_share: generate_random_bytes(),
628+
encrypted_input_share: HpkeCiphertext {
629+
config_id: rand::thread_rng().gen::<u8>(),
630+
enc: generate_random_bytes(),
631+
payload: generate_random_bytes(),
632+
},
633+
}
528634
}
529635
}

crates/daphne-worker/src/durable/replay_checker.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ super::mk_durable_object! {
3131
///
3232
/// - The report id has been seen but belongs to a different agg job id:
3333
/// => We return it in duplicates set
34+
///
35+
///
36+
/// This is only correct for as long as two invariants are upheld somewhere else in the code:
37+
/// - We stricly require aggregation jobs to not change once submited (this is required per the DAP spec)
38+
/// - All storage operations after this point that rely on replay protection are idempotent.
3439
struct ReplayChecker {
3540
state: State,
3641
env: Env,
@@ -59,6 +64,8 @@ impl GcDurableObject for ReplayChecker {
5964

6065
let mut duplicates = HashSet::new();
6166

67+
// Check the cache for duplicates and compute the set of report IDs we need to read
68+
// from the disk.
6269
let report_ids_as_string = report_ids
6370
.iter()
6471
.filter(|r| match self.seen.get(r) {
@@ -73,7 +80,7 @@ impl GcDurableObject for ReplayChecker {
7380
.map(ToString::to_string)
7481
.collect::<Vec<_>>();
7582

76-
let aggregation_job_id_as_str = agg_job_id.to_string();
83+
let agg_job_id_as_str = agg_job_id.to_string();
7784

7885
let result = self
7986
.state
@@ -87,14 +94,14 @@ impl GcDurableObject for ReplayChecker {
8794

8895
let v = result.get(&JsValue::from_str(as_str));
8996
if let Some(stored_agg_job_id) = v.as_string() {
90-
if stored_agg_job_id != aggregation_job_id_as_str {
97+
if stored_agg_job_id != agg_job_id_as_str {
9198
duplicates.insert(*id);
9299
}
93100
} else {
94101
js_sys::Reflect::set(
95102
&obj_to_update,
96103
&JsValue::from_str(as_str),
97-
&JsValue::from_str(aggregation_job_id_as_str.as_ref()),
104+
&JsValue::from_str(agg_job_id_as_str.as_ref()),
98105
)?;
99106
}
100107
}

crates/daphne/src/protocol/report_init.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ impl<'s> From<&'s PartialDapTaskConfigForReportInit<'_>> for PartialDapTaskConfi
123123
}
124124
}
125125

126+
#[derive(Clone)]
127+
#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq, Debug))]
126128
pub struct PartialDapTaskConfigForReportInit<'s> {
127129
pub not_before: messages::Time,
128130
pub not_after: messages::Time,

0 commit comments

Comments
 (0)