diff --git a/Cargo.lock b/Cargo.lock index 1ef81975f..10a881bcb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -926,6 +926,7 @@ dependencies = [ "prio 0.17.0-alpha.0", "prometheus", "rand", + "rayon", "rcgen", "reqwest", "serde", @@ -952,6 +953,7 @@ dependencies = [ "prio 0.17.0-alpha.0", "rand", "serde", + "serde_json", "url", ] @@ -1003,10 +1005,12 @@ dependencies = [ name = "daphne-worker-test" version = "0.3.0" dependencies = [ + "async-trait", "cap", "cfg-if", "console_error_panic_hook", "daphne-worker", + "futures", "prometheus", "tracing", "worker", diff --git a/crates/dapf/src/acceptance/load_testing.rs b/crates/dapf/src/acceptance/load_testing.rs index cf0d9966c..c5aea384e 100644 --- a/crates/dapf/src/acceptance/load_testing.rs +++ b/crates/dapf/src/acceptance/load_testing.rs @@ -448,6 +448,7 @@ pub async fn execute_single_combination_from_env( &measurment, VERSION, system_now.0, + Some(vec![]), vec![messages::Extension::Taskprov], t.replay_reports, ) diff --git a/crates/dapf/src/acceptance/mod.rs b/crates/dapf/src/acceptance/mod.rs index dc82028ab..91ef43073 100644 --- a/crates/dapf/src/acceptance/mod.rs +++ b/crates/dapf/src/acceptance/mod.rs @@ -655,6 +655,10 @@ impl Test { measurement.as_ref(), version, now.0, + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(vec![]), + }, vec![messages::Extension::Taskprov], self.replay_reports, ) diff --git a/crates/daphne-server/tests/e2e/e2e.rs b/crates/daphne-server/tests/e2e/e2e.rs index 5d7057d11..ce27fbe9a 100644 --- a/crates/daphne-server/tests/e2e/e2e.rs +++ b/crates/daphne-server/tests/e2e/e2e.rs @@ -304,6 +304,10 @@ async fn leader_upload(version: DapVersion) { report_metadata: ReportMetadata { id: ReportId([1; 16]), time: t.now, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_shares: [ @@ -424,6 +428,10 @@ async fn leader_upload_taskprov() { t.now, &task_id, 
DapMeasurement::U32Vec(vec![1; 10]), + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(vec![]), + }, vec![Extension::Taskprov], version, ) @@ -451,6 +459,10 @@ async fn leader_upload_taskprov() { t.now, &task_id, DapMeasurement::U32Vec(vec![1; 10]), + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(vec![]), + }, vec![Extension::Taskprov], version, ) @@ -516,6 +528,10 @@ async fn leader_upload_taskprov_wrong_version(version: DapVersion) { t.now, &task_id, DapMeasurement::U32Vec(vec![1; 10]), + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(vec![]), + }, vec![Extension::Taskprov], version, ) @@ -541,6 +557,100 @@ async fn leader_upload_taskprov_wrong_version(version: DapVersion) { async_test_versions!(leader_upload_taskprov_wrong_version); +#[tokio::test] +async fn leader_upload_taksprov_public_errors() { + let version = DapVersion::Latest; + let t = TestRunner::default_with_version(version).await; + let client = t.http_client(); + let hpke_config_list = t.get_hpke_configs(version, client).await.unwrap(); + + let (task_config, task_id, taskprov_advertisement) = DapTaskParameters { + version, + min_batch_size: 10, + query: DapBatchMode::TimeInterval, + leader_url: t.task_config.leader_url.clone(), + helper_url: t.task_config.helper_url.clone(), + ..Default::default() + } + .to_config_with_taskprov( + b"cool task".to_vec(), + t.now, + daphne::roles::aggregator::TaskprovConfig { + hpke_collector_config: &t.taskprov_collector_hpke_receiver.config, + vdaf_verify_key_init: &t.taskprov_vdaf_verify_key_init, + }, + ) + .unwrap(); + + // Repeated public extension + let report = task_config + .vdaf + .produce_report_with_extensions( + &hpke_config_list, + t.now + 1, + &task_id, + DapMeasurement::U32Vec(vec![1; 10]), + Some(vec![Extension::Taskprov, Extension::Taskprov]), + vec![], + version, + ) + .unwrap(); + t.leader_request_expect_abort( + client, + None, + &format!("tasks/{}/reports", 
task_id.to_base64url()), + &http::Method::POST, + DapMediaType::Report, + Some( + &taskprov_advertisement + .serialize_to_header_value(version) + .unwrap(), + ), + report.get_encoded_with_param(&version).unwrap(), + 400, + "invalidMessage", + ) + .await + .unwrap(); + + // Unsupported public extension + let report = task_config + .vdaf + .produce_report_with_extensions( + &hpke_config_list, + t.now + 1, + &task_id, + DapMeasurement::U32Vec(vec![1; 10]), + Some(vec![ + Extension::Taskprov, + Extension::NotImplemented { + typ: 3, + payload: b"ignore".to_vec(), + }, + ]), + vec![], + version, + ) + .unwrap(); + t.leader_request_expect_abort( + client, + None, + &format!("tasks/{}/reports", task_id.to_base64url()), + &http::Method::POST, + DapMediaType::Report, + Some( + &taskprov_advertisement + .serialize_to_header_value(version) + .unwrap(), + ), + report.get_encoded_with_param(&version).unwrap(), + 400, + "unsupportedExtension", + ) + .await + .unwrap(); +} + async fn internal_leader_process(version: DapVersion) { let t = TestRunner::default_with_version(version).await; let path = t.upload_path(); @@ -1408,6 +1518,10 @@ async fn leader_collect_taskprov_ok(version: DapVersion) { now, &task_id, DapMeasurement::U32Vec(vec![1; 10]), + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(vec![]), + }, extensions, version, ) diff --git a/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp b/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp index 3fa6bea48..19f1858a0 100644 --- a/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp +++ b/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp @@ -53,10 +53,24 @@ struct PartialDapTaskConfig @0xdcc9bf18fc62d406 { vdafVerifyKey @4 :VdafVerifyKey; } +struct PublicExtensionsList @0x8b3c98c0ddd0043e { + + union { + # Each extension is encoded according to the DAP spec in + # tag-length-value form. 
+ list @0 :List(Data); + + # draft09 compatibility: Previously DAP had no extensions in the + # report. + none @1 :Void; + } +} + struct ReportMetadata @0xefba178ad4584bc4 { - id @0 :Base.ReportId; - time @1 :Base.Time; + id @0 :Base.ReportId; + time @1 :Base.Time; + publicExtensions @2 :PublicExtensionsList; } struct PrepareInit @0x8192568cb3d03f59 { diff --git a/crates/daphne-service-utils/src/compute_offload/mod.rs b/crates/daphne-service-utils/src/compute_offload/mod.rs index dcc1485f6..6a1489128 100644 --- a/crates/daphne-service-utils/src/compute_offload/mod.rs +++ b/crates/daphne-service-utils/src/compute_offload/mod.rs @@ -10,17 +10,17 @@ use crate::{ hpke_receiver_config::{self, hpke_config}, initialize_reports, initialized_reports::{self, initialized_report}, - partial_dap_task_config, prepare_init, report_metadata, time_range, + partial_dap_task_config, prepare_init, public_extensions_list, report_metadata, time_range, }, }; use daphne::{ constants::DapAggregatorRole, hpke::{HpkeConfig, HpkeReceiverConfig}, - messages::{self, HpkeCiphertext, PrepareInit, ReportMetadata, ReportShare, TaskId}, + messages::{self, Extension, HpkeCiphertext, PrepareInit, ReportMetadata, ReportShare, TaskId}, vdaf::{VdafConfig, VdafPrepShare, VdafPrepState}, InitializedReport, PartialDapTaskConfigForReportInit, WithPeerPrepShare, }; -use prio::codec::{Encode, ParameterizedDecode, ParameterizedEncode}; +use prio::codec::{Decode, Encode, ParameterizedDecode, ParameterizedEncode}; use std::{borrow::Cow, ops::Range}; pub struct InitializeReports<'s> { @@ -318,9 +318,27 @@ impl CapnprotoPayloadEncode for ReportMetadata { type Builder<'a> = report_metadata::Builder<'a>; fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { - let Self { id, time } = self; + let Self { + id, + time, + public_extensions, + } = self; id.encode_to_builder(builder.reborrow().init_id()); builder.set_time(*time); + if let Some(ref extensions) = public_extensions { + let mut e = builder + 
.init_public_extensions() + .init_list(usize_to_capnp_len(extensions.len())); + for (i, data) in extensions + .iter() + .enumerate() + .map(|(i, ext)| (usize_to_capnp_len(i), ext.get_encoded().unwrap())) + { + e.reborrow().set(i, &data); + } + } else { + builder.init_public_extensions().set_none(()); + } } } @@ -331,9 +349,25 @@ impl CapnprotoPayloadDecode for ReportMetadata { where Self: Sized, { + let id = <_>::decode_from_reader(reader.get_id()?)?; + let time = reader.get_time(); + let public_extensions = match reader.get_public_extensions()?.which()? { + public_extensions_list::List(list) => Some( + list? + .into_iter() + .map(|data| { + Extension::get_decoded(data?) + .map_err(|e| capnp::Error::failed(e.to_string())) + }) + .collect::, capnp::Error>>()?, + ), + public_extensions_list::None(()) => None, + }; + Ok(Self { - id: <_>::decode_from_reader(reader.get_id()?)?, - time: reader.get_time(), + id, + time, + public_extensions, }) } } @@ -486,3 +520,45 @@ fn to_capnp(e: E) -> capnp::Error { extra: e.to_string(), } } + +#[cfg(test)] +mod test { + use super::*; + use crate::capnproto::{CapnprotoPayloadDecodeExt, CapnprotoPayloadEncodeExt}; + + #[test] + fn report_metadata_roundtrip() { + let report_metadata = ReportMetadata { + id: messages::ReportId(rand::random()), + time: rand::random(), + public_extensions: Some(vec![ + Extension::Taskprov, + Extension::NotImplemented { + typ: 23, + payload: b"some extension payload".to_vec(), + }, + ]), + }; + + assert_eq!( + report_metadata, + ReportMetadata::decode_from_bytes(&report_metadata.encode_to_bytes()).unwrap() + ); + } + + #[test] + fn report_metadata_roundtrip_draft09() { + let report_metadata = ReportMetadata { + id: messages::ReportId(rand::random()), + time: rand::random(), + // draft09 compatibility: Previously there was no extensions field in the report + // metadata. 
+ public_extensions: None, + }; + + assert_eq!( + report_metadata, + ReportMetadata::decode_from_bytes(&report_metadata.encode_to_bytes()).unwrap() + ); + } +} diff --git a/crates/daphne/src/error/aborts.rs b/crates/daphne/src/error/aborts.rs index c560e4757..45f56f448 100644 --- a/crates/daphne/src/error/aborts.rs +++ b/crates/daphne/src/error/aborts.rs @@ -95,6 +95,10 @@ pub enum DapAbort { /// Unrecognized DAP task. Sent in response to a request indicating an unrecognized task ID. #[error("unrecognizedTask")] UnrecognizedTask { task_id: TaskId }, + + /// Unsupported Extension. Sent in response to a report upload with an unsupported extension. + #[error("unsupportedExtension")] + UnsupportedExtension { detail: String, task_id: TaskId }, } impl DapAbort { @@ -116,7 +120,8 @@ impl DapAbort { | Self::InvalidBatchSize { detail, task_id } | Self::BatchModeMismatch { detail, task_id } | Self::UnauthorizedRequest { detail, task_id } - | Self::InvalidMessage { detail, task_id } => ( + | Self::InvalidMessage { detail, task_id } + | Self::UnsupportedExtension { detail, task_id } => ( Some(task_id), Some(detail), None, @@ -259,6 +264,16 @@ impl DapAbort { }) } + pub fn unsupported_extension( + task_id: &TaskId, + unknown_extensions: &[u16], + ) -> Result { + Ok(Self::UnsupportedExtension { + detail: format!("{unknown_extensions:?}"), + task_id: *task_id, + }) + } + fn title_and_type(&self) -> (&'static str, Option) { let (title, dap_abort_type) = match self { Self::BatchInvalid { .. } => ("Batch boundary check failed", Some(self.to_string())), @@ -300,6 +315,9 @@ impl DapAbort { Some(self.to_string()), ), Self::BadRequest(..) => ("Bad request", None), + Self::UnsupportedExtension { .. 
} => { + ("Unsupported extensions in report", Some(self.to_string())) + } }; ( diff --git a/crates/daphne/src/hpke.rs b/crates/daphne/src/hpke.rs index 2e3c8535e..aac757d61 100644 --- a/crates/daphne/src/hpke.rs +++ b/crates/daphne/src/hpke.rs @@ -612,6 +612,10 @@ mod test { report_metadata: &ReportMetadata { id: ReportId(rand::random()), time: rand::random(), + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, }; let plaintext = b"plaintext"; @@ -703,6 +707,10 @@ mod test { let report_metadata = &ReportMetadata { id: ReportId(rand::random()), time: rand::random(), + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }; let public_share = &vec![rand::random(); (0..100).choose(&mut rand::thread_rng()).unwrap()]; diff --git a/crates/daphne/src/messages/mod.rs b/crates/daphne/src/messages/mod.rs index f7a9ceaeb..8b86e2cb6 100644 --- a/crates/daphne/src/messages/mod.rs +++ b/crates/daphne/src/messages/mod.rs @@ -177,7 +177,7 @@ pub type Duration = u64; pub type Time = u64; /// Report extensions. -#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] +#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize, Hash)] #[serde(rename_all = "snake_case")] #[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] pub enum Extension { @@ -195,16 +195,13 @@ impl Extension { } } -impl ParameterizedEncode for Extension { - fn encode_with_param( - &self, - version: &DapVersion, - bytes: &mut Vec, - ) -> Result<(), CodecError> { +impl Encode for Extension { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { match self { Self::Taskprov => { EXTENSION_TASKPROV.encode(bytes)?; - encode_u16_prefixed(*version, bytes, |_, _| Ok(()))?; + // We've hard coded the version here, but we don't actually use it. 
+ encode_u16_prefixed(DapVersion::Latest, bytes, |_, _| Ok(()))?; } Self::NotImplemented { typ, payload } => { typ.encode(bytes)?; @@ -215,15 +212,15 @@ impl ParameterizedEncode for Extension { } } -impl ParameterizedDecode for Extension { - fn decode_with_param( - version: &DapVersion, - bytes: &mut Cursor<&[u8]>, - ) -> Result { +impl Decode for Extension { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result { let typ = u16::decode(bytes)?; match typ { EXTENSION_TASKPROV => { - decode_u16_prefixed(*version, bytes, |_version, inner, _len| <()>::decode(inner))?; + // We've hard coded the version here, but we don't actually use it. + decode_u16_prefixed(DapVersion::Latest, bytes, |_version, inner, _len| { + <()>::decode(inner) + })?; Ok(Self::Taskprov) } _ => Ok(Self::NotImplemented { @@ -240,28 +237,45 @@ impl ParameterizedDecode for Extension { pub struct ReportMetadata { pub id: ReportId, pub time: Time, + pub public_extensions: Option>, } impl ParameterizedEncode for ReportMetadata { fn encode_with_param( &self, - _version: &DapVersion, + version: &DapVersion, bytes: &mut Vec, ) -> Result<(), CodecError> { self.id.encode(bytes)?; self.time.encode(bytes)?; + match (version, &self.public_extensions) { + (DapVersion::Draft09, None) => (), + (DapVersion::Latest, Some(extensions)) => { + encode_u16_items(bytes, version, extensions.as_slice())?; + } + _ => { + return Err(CodecError::Other( + "encountered incorrectly set public extensions".into(), + )) + } + } + Ok(()) } } impl ParameterizedDecode for ReportMetadata { fn decode_with_param( - _version: &DapVersion, + version: &DapVersion, bytes: &mut Cursor<&[u8]>, ) -> Result { let metadata = Self { id: ReportId::decode(bytes)?, time: Time::decode(bytes)?, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(decode_u16_items(version, bytes)?), + }, }; Ok(metadata) @@ -1406,7 +1420,7 @@ impl Decode for HpkeCiphertext { /// A plaintext input share. 
#[derive(Clone, Debug, PartialEq, Eq)] pub struct PlaintextInputShare { - pub extensions: Vec, + pub private_extensions: Vec, pub payload: Vec, } @@ -1416,7 +1430,7 @@ impl ParameterizedEncode for PlaintextInputShare { version: &DapVersion, bytes: &mut Vec, ) -> Result<(), CodecError> { - encode_u16_items(bytes, version, &self.extensions)?; + encode_u16_items(bytes, version, &self.private_extensions)?; encode_u32_bytes(bytes, &self.payload)?; Ok(()) } @@ -1428,7 +1442,7 @@ impl ParameterizedDecode for PlaintextInputShare { bytes: &mut Cursor<&[u8]>, ) -> Result { Ok(Self { - extensions: decode_u16_items(version, bytes)?, + private_extensions: decode_u16_items(version, bytes)?, payload: decode_u32_bytes(bytes)?, }) } @@ -1570,6 +1584,10 @@ mod test { report_metadata: ReportMetadata { id: ReportId([23; 16]), time: 1_637_364_244, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_shares: [ @@ -1597,6 +1615,59 @@ mod test { test_versions! {read_report} + fn report_metadata_encode_decode(version: DapVersion) { + let ext_rm = ReportMetadata { + id: ReportId([15; 16]), + time: 123_456, + public_extensions: Some(vec![Extension::NotImplemented { + typ: 0x10, + payload: vec![0x11, 0x12], + }]), + }; + let no_ext_rm = ReportMetadata { + id: ReportId([13; 16]), + time: 123_456, + public_extensions: None, + }; + let good_rm = match version { + DapVersion::Draft09 => &no_ext_rm, + DapVersion::Latest => &ext_rm, + }; + let bad_rm = match version { + DapVersion::Draft09 => &ext_rm, + DapVersion::Latest => &no_ext_rm, + }; + assert!(matches!( + bad_rm.get_encoded_with_param(&version).unwrap_err(), + CodecError::Other(..) + )); + let bytes = good_rm.get_encoded_with_param(&version).unwrap(); + assert_eq!( + ReportMetadata::get_decoded_with_param(&version, bytes.as_slice()).unwrap(), + *good_rm + ); + } + + test_versions! 
{report_metadata_encode_decode} + + #[test] + fn report_metadata_encode_latest_decode_draft09() { + let ext_rm = ReportMetadata { + id: ReportId([15; 16]), + time: 123_456, + public_extensions: Some(vec![Extension::NotImplemented { + typ: 0x10, + payload: vec![0x11, 0x12], + }]), + }; + let bytes = ext_rm.get_encoded_with_param(&DapVersion::Latest).unwrap(); + assert!(matches!( + ReportMetadata::get_decoded_with_param(&DapVersion::Draft09, bytes.as_slice()) + .unwrap_err(), + CodecError::BytesLeftOver(..) + )); + } + fn partial_batch_selector_encode_decode(version: DapVersion) { const TEST_DATA_DRAFT09: &[u8] = &[1]; const TEST_DATA_LATEST: &[u8] = &[1, 0, 0]; @@ -1684,14 +1755,15 @@ mod test { 0, 0, 0, 32, 116, 104, 105, 115, 32, 105, 115, 32, 97, 110, 32, 97, 103, 103, 114, 101, 103, 97, 116, 105, 111, 110, 32, 112, 97, 114, 97, 109, 101, 116, 101, 114, 2, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 158, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, - 0, 0, 0, 0, 97, 152, 38, 185, 0, 0, 0, 12, 112, 117, 98, 108, 105, 99, 32, 115, 104, - 97, 114, 101, 23, 0, 16, 101, 110, 99, 97, 112, 115, 117, 108, 97, 116, 101, 100, 32, - 107, 101, 121, 0, 0, 0, 10, 99, 105, 112, 104, 101, 114, 116, 101, 120, 116, 0, 0, 0, - 10, 112, 114, 101, 112, 32, 115, 104, 97, 114, 101, 17, 17, 17, 17, 17, 17, 17, 17, 17, - 17, 17, 17, 17, 17, 17, 17, 0, 0, 0, 0, 9, 194, 107, 103, 0, 0, 0, 12, 112, 117, 98, - 108, 105, 99, 32, 115, 104, 97, 114, 101, 0, 0, 0, 0, 0, 0, 10, 99, 105, 112, 104, 101, - 114, 116, 101, 120, 116, 0, 0, 0, 10, 112, 114, 101, 112, 32, 115, 104, 97, 114, 101, + 0, 0, 0, 0, 0, 0, 162, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, + 0, 0, 0, 0, 97, 152, 38, 185, 0, 0, 0, 0, 0, 12, 112, 117, 98, 108, 105, 99, 32, 115, + 104, 97, 114, 101, 23, 0, 16, 101, 110, 99, 97, 112, 115, 117, 108, 97, 116, 101, 100, + 32, 107, 101, 121, 0, 0, 0, 10, 99, 105, 112, 104, 101, 
114, 116, 101, 120, 116, 0, 0, + 0, 10, 112, 114, 101, 112, 32, 115, 104, 97, 114, 101, 17, 17, 17, 17, 17, 17, 17, 17, + 17, 17, 17, 17, 17, 17, 17, 17, 0, 0, 0, 0, 9, 194, 107, 103, 0, 0, 0, 0, 0, 12, 112, + 117, 98, 108, 105, 99, 32, 115, 104, 97, 114, 101, 0, 0, 0, 0, 0, 0, 10, 99, 105, 112, + 104, 101, 114, 116, 101, 120, 116, 0, 0, 0, 10, 112, 114, 101, 112, 32, 115, 104, 97, + 114, 101, ]; let want = AggregationJobInitReq { @@ -1705,6 +1777,10 @@ mod test { report_metadata: ReportMetadata { id: ReportId([99; 16]), time: 1_637_361_337, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_share: HpkeCiphertext { @@ -1720,6 +1796,10 @@ mod test { report_metadata: ReportMetadata { id: ReportId([17; 16]), time: 163_736_423, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_share: HpkeCiphertext { @@ -1759,6 +1839,10 @@ mod test { report_metadata: ReportMetadata { id: ReportId([99; 16]), time: 1_637_361_337, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_share: HpkeCiphertext { @@ -1774,6 +1858,10 @@ mod test { report_metadata: ReportMetadata { id: ReportId([17; 16]), time: 163_736_423, + public_extensions: match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, }, public_share: b"public share".to_vec(), encrypted_input_share: HpkeCiphertext { diff --git a/crates/daphne/src/protocol/aggregator.rs b/crates/daphne/src/protocol/aggregator.rs index d96b88fb2..b169f3664 100644 --- a/crates/daphne/src/protocol/aggregator.rs +++ b/crates/daphne/src/protocol/aggregator.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause use super::{ - no_duplicates, + 
check_no_duplicates, report_init::{InitializedReport, WithPeerPrepShare}, }; use crate::{ @@ -241,7 +241,7 @@ impl DapTaskConfig { DapAggregationParam::get_decoded_with_param(&self.vdaf, &agg_job_init_req.agg_param) .map_err(|e| DapAbort::from_codec_error(e, *task_id))?; if replay_protection.enabled() { - no_duplicates( + check_no_duplicates( agg_job_init_req .prep_inits .iter() diff --git a/crates/daphne/src/protocol/client.rs b/crates/daphne/src/protocol/client.rs index d0f707b4f..6358d4821 100644 --- a/crates/daphne/src/protocol/client.rs +++ b/crates/daphne/src/protocol/client.rs @@ -3,6 +3,7 @@ use crate::{ constants::DapAggregatorRole, + fatal_error, hpke::{info_and_aad, HpkeConfig}, messages::{Extension, PlaintextInputShare, Report, ReportId, ReportMetadata, TaskId, Time}, DapError, DapMeasurement, DapVersion, VdafConfig, @@ -30,13 +31,15 @@ impl VdafConfig { /// * `extensions` are the extensions. /// /// * `version` is the `DapVersion` to use. + #[allow(clippy::too_many_arguments)] pub fn produce_report_with_extensions( &self, hpke_config_list: &[HpkeConfig; 2], time: Time, task_id: &TaskId, measurement: DapMeasurement, - extensions: Vec, + public_extensions: Option>, + private_extensions: Vec, version: DapVersion, ) -> Result { let mut rng = thread_rng(); @@ -51,7 +54,8 @@ impl VdafConfig { time, task_id, &report_id, - extensions, + public_extensions, + private_extensions, version, ) } @@ -65,17 +69,27 @@ impl VdafConfig { time: Time, task_id: &TaskId, report_id: &ReportId, - extensions: Vec, + public_extensions: Option>, + private_extensions: Vec, version: DapVersion, ) -> Result { + if let (Some(_), DapVersion::Draft09) | (None, DapVersion::Latest) = + (&public_extensions, version) + { + return Err(fatal_error!( + err = format!("public extensions not set correctly for {version:?}") + )); + } + let mut plaintext_input_share = PlaintextInputShare { - extensions, + private_extensions, payload: Vec::default(), }; let metadata = ReportMetadata { id: 
*report_id, time, + public_extensions, }; let encoded_input_shares = input_shares.into_iter().map(|input_share| { @@ -143,6 +157,10 @@ impl VdafConfig { time, task_id, measurement, + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(Vec::new()), + }, Vec::new(), version, ) diff --git a/crates/daphne/src/protocol/mod.rs b/crates/daphne/src/protocol/mod.rs index 3dbb9832e..2f39c3efa 100644 --- a/crates/daphne/src/protocol/mod.rs +++ b/crates/daphne/src/protocol/mod.rs @@ -11,7 +11,7 @@ pub(crate) mod report_init; /// checks if an iterator has no duplicate items, returns the ok if there are no dups or an error /// with the first offending item. -pub(crate) fn no_duplicates(iterator: I) -> Result<(), I::Item> +pub(crate) fn check_no_duplicates(iterator: I) -> Result<(), I::Item> where I: Iterator, I::Item: Eq + std::hash::Hash, @@ -752,6 +752,10 @@ mod test { t.now, &t.task_id, DapMeasurement::U32Vec(vec![1; 10]), + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(vec![]), + }, vec![Extension::NotImplemented { typ: 0xffff, payload: b"some extension data".to_vec(), @@ -784,6 +788,58 @@ mod test { test_versions! 
{ handle_unrecognized_report_extensions } + #[test] + fn handle_unknown_public_extensions_in_report() { + let version = DapVersion::Latest; + let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); + let report = t + .task_config + .vdaf + .produce_report_with_extensions( + &t.client_hpke_config_list, + t.now, + &t.task_id, + DapMeasurement::U32Vec(vec![1; 10]), + Some(vec![ + Extension::NotImplemented { + typ: 0x01, + payload: b"This is ignored".to_vec(), + }, + Extension::NotImplemented { + typ: 0x02, + payload: b"This is ignored too".to_vec(), + }, + ]), + vec![], + version, + ) + .unwrap(); + let report_metadata = report.report_metadata.clone(); + let [leader_share, _] = report.encrypted_input_shares; + let initialized_report = InitializedReport::from_client( + &t.leader_hpke_receiver_config, + t.valid_report_time_range(), + &t.task_id, + &t.task_config, + ReportShare { + report_metadata: report.report_metadata, + public_share: report.public_share, + encrypted_input_share: leader_share, + }, + &DapAggregationParam::Empty, + ) + .unwrap(); + + assert_eq!(initialized_report.metadata(), &report_metadata); + assert_matches!( + initialized_report, + InitializedReport::Rejected { + report_err: ReportError::InvalidMessage, + .. 
+ } + ); + } + fn handle_repeated_report_extensions(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let report = t @@ -794,6 +850,10 @@ mod test { t.now, &t.task_id, DapMeasurement::U32Vec(vec![1; 10]), + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(vec![]), + }, vec![ Extension::NotImplemented { typ: 23, @@ -856,6 +916,10 @@ mod test { self.now, &self.task_id, &report_id, + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(vec![]), + }, Vec::new(), // extensions version, ) @@ -887,6 +951,10 @@ mod test { self.now, &self.task_id, &report_id, + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(vec![]), + }, Vec::new(), // extensions version, ) @@ -919,6 +987,10 @@ mod test { self.now, &self.task_id, &report_id, + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(vec![]), + }, Vec::new(), // extensions version, ) diff --git a/crates/daphne/src/protocol/report_init.rs b/crates/daphne/src/protocol/report_init.rs index 82adf2945..7502ab554 100644 --- a/crates/daphne/src/protocol/report_init.rs +++ b/crates/daphne/src/protocol/report_init.rs @@ -7,7 +7,7 @@ use crate::{ messages::{ self, Extension, PlaintextInputShare, ReportError, ReportMetadata, ReportShare, TaskId, }, - protocol::{decode_ping_pong_framed, no_duplicates, PingPongMessageType}, + protocol::{check_no_duplicates, decode_ping_pong_framed, PingPongMessageType}, vdaf::{VdafConfig, VdafPrepShare, VdafPrepState, VdafVerifyKey}, DapAggregationParam, DapError, DapTaskConfig, DapVersion, }; @@ -158,9 +158,16 @@ impl

InitializedReport

{ _ => {} } + match ( + &report_share.report_metadata.public_extensions, + task_config.version, + ) { + (Some(..), crate::DapVersion::Latest) | (None, crate::DapVersion::Draft09) => (), + (_, _) => reject!(InvalidMessage), + } // decrypt input share let PlaintextInputShare { - extensions, + private_extensions, payload: input_share, } = { let info = info_and_aad::InputShare { @@ -194,18 +201,38 @@ impl

InitializedReport

{ // Handle report extensions. { - if no_duplicates(extensions.iter().map(|e| e.type_code())).is_err() { + // Check for duplicates in public and private extensions + if check_no_duplicates( + private_extensions + .iter() + .chain( + report_share + .report_metadata + .public_extensions + .as_deref() + .unwrap_or_default(), + ) + .map(|e| e.type_code()), + ) + .is_err() + { reject!(InvalidMessage) } + let mut taskprov_indicated = false; - for extension in extensions { + for extension in private_extensions.iter().chain( + report_share + .report_metadata + .public_extensions + .as_deref() + .unwrap_or_default(), + ) { match extension { - Extension::Taskprov { .. } if task_config.method_is_taskprov => { - taskprov_indicated = true; + Extension::Taskprov { .. } => { + taskprov_indicated = task_config.method_is_taskprov; } - // Reject reports with unrecognized extensions. - _ => reject!(InvalidMessage), + Extension::NotImplemented { .. } => reject!(InvalidMessage), } } diff --git a/crates/daphne/src/roles/helper/handle_agg_job.rs b/crates/daphne/src/roles/helper/handle_agg_job.rs index 699abe45c..026a399b5 100644 --- a/crates/daphne/src/roles/helper/handle_agg_job.rs +++ b/crates/daphne/src/roles/helper/handle_agg_job.rs @@ -179,7 +179,7 @@ impl HandleAggJob { > { let task_id = self.state.request.task_id; if replay_protection.enabled() { - crate::protocol::no_duplicates( + crate::protocol::check_no_duplicates( self.state .request .payload diff --git a/crates/daphne/src/roles/leader/mod.rs b/crates/daphne/src/roles/leader/mod.rs index aed51c207..c65c0a8f8 100644 --- a/crates/daphne/src/roles/leader/mod.rs +++ b/crates/daphne/src/roles/leader/mod.rs @@ -20,9 +20,10 @@ use crate::{ messages::{ taskprov::TaskprovAdvertisement, AggregateShare, AggregateShareReq, AggregationJobId, AggregationJobResp, Base64Encode, BatchId, BatchSelector, Collection, CollectionJobId, - CollectionReq, Interval, PartialBatchSelector, Query, Report, TaskId, + CollectionReq, Extension, Interval, 
PartialBatchSelector, Query, Report, TaskId, }, metrics::{DaphneRequestType, ReportStatus}, + protocol, roles::resolve_task_config, DapAggregationParam, DapCollectionJob, DapError, DapLeaderProcessTelemetry, DapRequest, DapRequestMeta, DapResponse, DapTaskConfig, DapVersion, @@ -214,6 +215,8 @@ pub async fn handle_upload_req( } .into()); } + + // Check that the report was generated after the task's `not_before` time. if report.report_metadata.time < task_config.as_ref().not_before - task_config.as_ref().time_precision { @@ -223,6 +226,34 @@ pub async fn handle_upload_req( .into()); } + if let Some(public_extensions) = &report.report_metadata.public_extensions { + // We can be sure at this point that the ReportMetadata is well formed + // because the decoding / error checking happens in the extractor. + assert_eq!(DapVersion::Latest, task_config.version); + let mut unknown_extensions = Vec::::new(); + if protocol::check_no_duplicates(public_extensions.iter()).is_err() { + return Err(DapError::Abort(DapAbort::InvalidMessage { + detail: "Repeated public extension".into(), + task_id, + })); + }; + for extension in public_extensions { + match extension { + Extension::Taskprov => (), + Extension::NotImplemented { typ, .. } => unknown_extensions.push(*typ), + } + } + + if !unknown_extensions.is_empty() { + return match DapAbort::unsupported_extension(&task_id, &unknown_extensions) { + Ok(abort) => Err::<(), DapError>(abort.into()), + Err(fatal) => Err(fatal), + }; + } + } else { + assert_eq!(DapVersion::Draft09, task_config.version); + } + // Store the report for future processing. At this point, the report may be rejected if // the Leader detects that the report was replayed or pertains to a batch that has already // been collected. diff --git a/crates/daphne/src/roles/mod.rs b/crates/daphne/src/roles/mod.rs index 56c3df982..1f4dadd11 100644 --- a/crates/daphne/src/roles/mod.rs +++ b/crates/daphne/src/roles/mod.rs @@ -772,6 +772,87 @@ mod test { async_test_versions! 
{ handle_agg_job_req_failure_hpke_decrypt_error } + #[tokio::test] + async fn handle_unknown_public_extensions() { + let version = DapVersion::Latest; + let t = Test::new(version); + let task_id = &t.time_interval_task_id; + let task_config = t.leader.unchecked_get_task_config(task_id).await; + + // Construct HPKE config list. + let hpke_config_list = [ + t.leader + .get_hpke_config_for(task_config.version, Some(task_id)) + .await + .unwrap() + .clone(), + t.helper + .get_hpke_config_for(task_config.version, Some(task_id)) + .await + .unwrap() + .clone(), + ]; + + let report = task_config + .vdaf + .produce_report_with_extensions( + &hpke_config_list, + t.now, + task_id, + DapMeasurement::U32Vec(vec![1; 10]), + Some(vec![Extension::NotImplemented { + typ: 0x01, + payload: vec![0x01], + }]), + vec![], + task_config.version, + ) + .unwrap(); + + let req = DapRequest { + meta: DapRequestMeta { + version: task_config.version, + media_type: Some(DapMediaType::Report), + task_id: *task_id, + ..Default::default() + }, + resource_id: (), + payload: report, + }; + assert_eq!( + leader::handle_upload_req(&*t.leader, req).await, + Err(DapError::Abort(DapAbort::UnsupportedExtension { + detail: "[1]".into(), + task_id: *task_id + })) + ); + } + + #[tokio::test] + #[should_panic(expected = "assertion `left == right` failed\n left: Latest\n right: Draft09")] + async fn handle_public_extensions_draft09() { + let version = DapVersion::Draft09; + let t = Test::new(version); + let task_id = &t.time_interval_task_id; + let task_config = t.leader.unchecked_get_task_config(task_id).await; + let mut report = t.gen_test_report(task_id).await; + // This change breaks the HPKE decryption, but triggers a failure + // before the HPKE data is checked. 
+ report.report_metadata.public_extensions = Some(vec![]); + + let req = DapRequest { + meta: DapRequestMeta { + version: task_config.version, + media_type: Some(DapMediaType::Report), + task_id: *task_id, + ..Default::default() + }, + resource_id: (), + payload: report, + }; + _ = leader::handle_upload_req(&*t.leader, req).await; + } + async fn handle_agg_job_req_transition_continue(version: DapVersion) { let t = Test::new(version); let task_id = &t.time_interval_task_id; @@ -1484,6 +1565,10 @@ mod test { t.now, &task_id, test_measurement.clone(), + match version { + DapVersion::Draft09 => None, + DapVersion::Latest => Some(vec![]), + }, vec![Extension::Taskprov], task_config.version, ) @@ -1600,6 +1685,271 @@ mod test { .await; } + #[tokio::test] + async fn leader_upload_taskprov_public() { + let version = DapVersion::Latest; + let t = Test::new(DapVersion::Latest); + + let (task_config, task_id, taskprov_advertisement) = DapTaskParameters { + version, + min_batch_size: 1, + query: DapBatchMode::LeaderSelected { + max_batch_size: Some(NonZeroU32::new(2).unwrap()), + }, + ..Default::default() + } + .to_config_with_taskprov( + b"cool task".to_vec(), + t.now, + t.leader.get_taskprov_config().unwrap(), + ) + .unwrap(); + + // Clients: Send upload request to Leader. 
+ let hpke_config_list = [ + t.leader + .get_hpke_config_for(version, Some(&task_id)) + .await + .unwrap() + .clone(), + t.helper + .get_hpke_config_for(version, Some(&task_id)) + .await + .unwrap() + .clone(), + ]; + + for _ in 0..task_config.min_batch_size { + let report = task_config + .vdaf + .produce_report_with_extensions( + &hpke_config_list, + t.now + 1, + &task_id, + DapMeasurement::U32Vec(vec![1; 10]), + Some(vec![Extension::Taskprov]), + vec![], + version, + ) + .unwrap(); + let req = DapRequest { + meta: DapRequestMeta { + version, + media_type: Some(DapMediaType::Report), + task_id, + taskprov_advertisement: Some(taskprov_advertisement.clone()), + }, + resource_id: (), + payload: report, + }; + leader::handle_upload_req(&*t.leader, req).await.unwrap(); + } + // Collector: Request result from the Leader. + let query = Query::LeaderSelectedCurrentBatch; + leader::handle_coll_job_req(&*t.leader, &t.gen_test_coll_job_req(query, &task_id).await) + .await + .unwrap(); + + leader::process(&*t.leader, "leader.com", 100) + .await + .unwrap(); + + assert_metrics_include!(t.helper_registry, { + r#"inbound_request_counter{env="test_helper",host="helper.org",type="aggregate"}"#: 1, + r#"inbound_request_counter{env="test_helper",host="helper.org",type="collect"}"#: 1, + r#"report_counter{env="test_helper",host="helper.org",status="aggregated"}"#: 1, + r#"report_counter{env="test_helper",host="helper.org",status="collected"}"#: 1, + r#"aggregation_job_counter{env="test_helper",host="helper.org",status="started"}"#: 1, + r#"aggregation_job_counter{env="test_helper",host="helper.org",status="completed"}"#: 1, + }); + assert_metrics_include!(t.leader_registry, { + r#"report_counter{env="test_leader",host="leader.com",status="aggregated"}"#: 1, + r#"report_counter{env="test_leader",host="leader.com",status="collected"}"#: 1, + }); + } + + #[tokio::test] + async fn leader_upload_taskprov_public_extension_errors() { + let version = DapVersion::Latest; + let t = 
Test::new(DapVersion::Latest); + + let (task_config, task_id, taskprov_advertisement) = DapTaskParameters { + version, + min_batch_size: 1, + query: DapBatchMode::LeaderSelected { + max_batch_size: Some(NonZeroU32::new(2).unwrap()), + }, + ..Default::default() + } + .to_config_with_taskprov( + b"cool task".to_vec(), + t.now, + t.leader.get_taskprov_config().unwrap(), + ) + .unwrap(); + + // Clients: Send upload request to Leader. + let hpke_config_list = [ + t.leader + .get_hpke_config_for(version, Some(&task_id)) + .await + .unwrap() + .clone(), + t.helper + .get_hpke_config_for(version, Some(&task_id)) + .await + .unwrap() + .clone(), + ]; + + let report = task_config + .vdaf + .produce_report_with_extensions( + &hpke_config_list, + t.now + 1, + &task_id, + DapMeasurement::U32Vec(vec![1; 10]), + Some(vec![Extension::Taskprov, Extension::Taskprov]), + vec![], + version, + ) + .unwrap(); + let req = DapRequest { + meta: DapRequestMeta { + version, + media_type: Some(DapMediaType::Report), + task_id, + taskprov_advertisement: Some(taskprov_advertisement.clone()), + }, + resource_id: (), + payload: report, + }; + assert_eq!( + DapError::Abort(DapAbort::InvalidMessage { + detail: "Repeated public extension".into(), + task_id, + }), + leader::handle_upload_req(&*t.leader, req) + .await + .unwrap_err() + ); + + let report = task_config + .vdaf + .produce_report_with_extensions( + &hpke_config_list, + t.now + 1, + &task_id, + DapMeasurement::U32Vec(vec![1; 10]), + Some(vec![ + Extension::Taskprov, + Extension::NotImplemented { + typ: 14, + payload: b"Ignore".into(), + }, + Extension::NotImplemented { + typ: 15, + payload: b"Ignore".into(), + }, + ]), + vec![], + version, + ) + .unwrap(); + let req = DapRequest { + meta: DapRequestMeta { + version, + media_type: Some(DapMediaType::Report), + task_id, + taskprov_advertisement: Some(taskprov_advertisement.clone()), + }, + resource_id: (), + payload: report, + }; + + assert_eq!( + 
DapError::Abort(DapAbort::unsupported_extension(&task_id, &[14, 15]).unwrap()), + leader::handle_upload_req(&*t.leader, req) + .await + .unwrap_err() + ); + } + + #[tokio::test] + async fn leader_upload_taskprov_in_public_and_private_extensions() { + let version = DapVersion::Latest; + let t = Test::new(DapVersion::Latest); + + let (task_config, task_id, taskprov_advertisement) = DapTaskParameters { + version, + min_batch_size: 1, + query: DapBatchMode::LeaderSelected { + max_batch_size: Some(NonZeroU32::new(2).unwrap()), + }, + ..Default::default() + } + .to_config_with_taskprov( + b"cool task".to_vec(), + t.now, + t.leader.get_taskprov_config().unwrap(), + ) + .unwrap(); + + // Clients: Send upload request to Leader. + let hpke_config_list = [ + t.leader + .get_hpke_config_for(version, Some(&task_id)) + .await + .unwrap() + .clone(), + t.helper + .get_hpke_config_for(version, Some(&task_id)) + .await + .unwrap() + .clone(), + ]; + + for _ in 0..task_config.min_batch_size { + let report = task_config + .vdaf + .produce_report_with_extensions( + &hpke_config_list, + t.now + 1, + &task_id, + DapMeasurement::U32Vec(vec![1; 10]), + Some(vec![Extension::Taskprov]), + vec![Extension::Taskprov], + version, + ) + .unwrap(); + let req = DapRequest { + meta: DapRequestMeta { + version, + media_type: Some(DapMediaType::Report), + task_id, + taskprov_advertisement: Some(taskprov_advertisement.clone()), + }, + resource_id: (), + payload: report, + }; + leader::handle_upload_req(&*t.leader, req).await.unwrap(); + } + // Collector: Request result from the Leader. 
+ let query = Query::LeaderSelectedCurrentBatch; + leader::handle_coll_job_req(&*t.leader, &t.gen_test_coll_job_req(query, &task_id).await) + .await + .unwrap(); + + leader::process(&*t.leader, "leader.com", 100) + .await + .unwrap(); + + assert_metrics_include!(t.leader_registry, { + r#"report_counter{env="test_leader",host="leader.com",status="rejected_invalid_message"}"#: 1, + r#"inbound_request_counter{env="test_leader",host="leader.com",type="upload"}"#: 1, + }); + } + // Test multiple tasks in flight at once. async fn multi_task(version: DapVersion) { let t = Test::new(version); diff --git a/crates/daphne/src/testing/report_generator.rs b/crates/daphne/src/testing/report_generator.rs index 6163fa0e6..3dd48f4a5 100644 --- a/crates/daphne/src/testing/report_generator.rs +++ b/crates/daphne/src/testing/report_generator.rs @@ -45,7 +45,8 @@ impl ReportGenerator { measurement: &DapMeasurement, version: DapVersion, now: Time, - extensions: Vec<Extension>, + public_extensions: Option<Vec<Extension>>, + private_extensions: Vec<Extension>, replay_reports: bool, ) -> Self { let (tx, rx) = mpsc::channel(); @@ -78,7 +79,8 @@ impl ReportGenerator { report_time_dist.sample(&mut thread_rng()), &task_id, measurement.clone(), - extensions.clone(), + public_extensions.clone(), + private_extensions.clone(), version, ) .expect("we have to panic here since we can't return the error") @@ -90,7 +92,8 @@ impl ReportGenerator { report_time_dist.sample(&mut thread_rng()), &task_id, measurement.clone(), - extensions.clone(), + public_extensions.clone(), + private_extensions.clone(), version, )? };