diff --git a/crates/dapf/src/functions/test_routes.rs b/crates/dapf/src/functions/test_routes.rs index 8dc6145ef..1b76b9324 100644 --- a/crates/dapf/src/functions/test_routes.rs +++ b/crates/dapf/src/functions/test_routes.rs @@ -5,7 +5,7 @@ //! //! [interop]: https://divergentdave.github.io/draft-dcook-ppm-dap-interop-test-design/draft-dcook-ppm-dap-interop-test-design.html -use anyhow::Context; +use anyhow::{bail, Context}; use daphne::{ hpke::{HpkeKemId, HpkeReceiverConfig}, messages::HpkeConfigList, @@ -53,12 +53,18 @@ impl HttpClient { } pub async fn delete_all_storage(&self, aggregator_url: &Url) -> anyhow::Result<()> { - self.post(aggregator_url.join("/internal/delete_all").unwrap()) + let resp = self + .post(aggregator_url.join("/internal/delete_all").unwrap()) .send() .await - .context("deleting storage")? - .error_for_status() .context("deleting storage")?; - Ok(()) + if resp.status().is_success() { + return Ok(()); + } + bail!( + "delete storage request failed. {} {}", + resp.status(), + resp.text().await? + ); } } diff --git a/crates/daphne-server/src/roles/aggregator.rs b/crates/daphne-server/src/roles/aggregator.rs index 90d5f4f68..060453b68 100644 --- a/crates/daphne-server/src/roles/aggregator.rs +++ b/crates/daphne-server/src/roles/aggregator.rs @@ -216,19 +216,22 @@ impl DapAggregator for crate::App { ) -> Result<(), DapError> { let expiration_time = task_config.not_after; - if self.service_config.role.is_leader() { - self.kv() - .put_with_expiration::( - task_id, - task_config, - expiration_time, - ) - .await - .map_err(|e| fatal_error!(err = ?e, "failed to put the a task config in kv"))?; - } else { - self.kv() - .only_cache_put::(task_id, task_config) - .await; + match self.service_config.role { + daphne::constants::DapAggregatorRole::Leader => { + self.kv() + .put_with_expiration::( + task_id, + task_config, + expiration_time, + ) + .await + .map_err(|e| fatal_error!(err = ?e, "failed to put the a task config in kv"))?; + } + daphne::constants::DapAggregatorRole::Helper => { + self.kv() + .only_cache_put::(task_id, task_config) + .await; + } } Ok(()) } diff --git a/crates/daphne-server/src/storage_proxy_connection/mod.rs b/crates/daphne-server/src/storage_proxy_connection/mod.rs index 85ff19d8f..4e5cb587f 100644 --- a/crates/daphne-server/src/storage_proxy_connection/mod.rs +++ b/crates/daphne-server/src/storage_proxy_connection/mod.rs @@ -6,9 +6,9 @@ pub(crate) mod kv; use std::fmt::Debug; use axum::http::StatusCode; -use daphne_service_utils::durable_requests::{ - bindings::{DurableMethod, DurableRequestPayload, DurableRequestPayloadExt}, - DurableRequest, ObjectIdFrom, DO_PATH_PREFIX, +use daphne_service_utils::{ + capnproto_payload::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt as _}, + durable_requests::{bindings::DurableMethod, DurableRequest, ObjectIdFrom, DO_PATH_PREFIX}, }; use serde::de::DeserializeOwned; @@ -95,7 +95,7 @@ impl<'d, B: DurableMethod + Debug, P: AsRef<[u8]>> RequestBuilder<'d, B, P> { } impl<'d, B: DurableMethod> RequestBuilder<'d, B, [u8; 0]> { - pub fn encode(self, payload: &T) -> RequestBuilder<'d, B, Vec> { + pub fn encode(self, payload: &T) -> RequestBuilder<'d, B, Vec> { self.with_body(payload.encode_to_bytes().unwrap()) } diff --git a/crates/daphne-service-utils/src/capnproto_payload.rs b/crates/daphne-service-utils/src/capnproto_payload.rs new file mode 100644 index 000000000..01dbbbe53 --- /dev/null +++ b/crates/daphne-service-utils/src/capnproto_payload.rs @@ -0,0 +1,54 @@ +// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +pub trait CapnprotoPayloadEncode { + fn encode_to_builder(&self) -> capnp::message::Builder; +} + +pub trait CapnprotoPayloadEncodeExt { + fn encode_to_bytes(&self) -> capnp::Result>; +} + +pub trait CapnprotoPayloadDecode { + fn decode_from_reader( + reader: capnp::message::Reader, + ) -> capnp::Result + where + Self: Sized; +} + +pub trait CapnprotoPayloadDecodeExt { + fn decode_from_bytes(bytes: &[u8]) -> capnp::Result + where + Self: Sized; +} + +impl CapnprotoPayloadEncodeExt for T +where + T: CapnprotoPayloadEncode, +{ + fn encode_to_bytes(&self) -> capnp::Result> { + let mut buf = Vec::new(); + let message = self.encode_to_builder(); + capnp::serialize_packed::write_message(&mut buf, &message)?; + Ok(buf) + } +} + +impl CapnprotoPayloadDecodeExt for T +where + T: CapnprotoPayloadDecode, +{ + fn decode_from_bytes(bytes: &[u8]) -> capnp::Result + where + Self: Sized, + { + let mut cursor = std::io::Cursor::new(bytes); + let reader = capnp::serialize_packed::read_message( + &mut cursor, + capnp::message::ReaderOptions::new(), + )?; + + T::decode_from_reader(reader) + } +} diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs index 4cac9c98b..51bf41fea 100644 --- a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs +++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs @@ -11,11 +11,11 @@ use daphne::{ use serde::{Deserialize, Serialize}; use crate::{ + capnproto_payload::{CapnprotoPayloadDecode, CapnprotoPayloadEncode}, durable_request_capnp::{aggregate_store_merge_req, dap_aggregate_share}, durable_requests::ObjectIdFrom, }; -use super::DurableRequestPayload; use prio::{ field::{FieldElement, FieldError}, vdaf::AggregateShare, @@ -68,7 +68,7 @@ pub struct AggregateStoreMergeOptions { pub skip_replay_protection: bool, } -impl DurableRequestPayload for AggregateStoreMergeReq { +impl CapnprotoPayloadEncode for AggregateStoreMergeReq { fn encode_to_builder(&self) -> capnp::message::Builder { let Self { contained_reports, @@ -162,7 +162,9 @@ impl DurableRequestPayload for AggregateStoreMergeReq { } message } +} +impl CapnprotoPayloadDecode for AggregateStoreMergeReq { fn decode_from_reader( reader: capnp::message::Reader, ) -> capnp::Result { @@ -285,7 +287,9 @@ mod test { }; use rand::{thread_rng, Rng}; - use crate::durable_requests::bindings::DurableRequestPayloadExt; + use crate::capnproto_payload::{ + CapnprotoPayloadDecodeExt as _, CapnprotoPayloadEncodeExt as _, + }; use super::*; diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs b/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs index 3807e9fca..911327237 100644 --- a/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs +++ b/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs @@ -37,48 +37,6 @@ pub trait DurableMethod { fn name(params: Self::NameParameters<'_>) -> ObjectIdFrom; } -pub trait DurableRequestPayload { - fn decode_from_reader( - reader: capnp::message::Reader, - ) -> capnp::Result - where - Self: Sized; - - fn encode_to_builder(&self) -> capnp::message::Builder; -} - -pub trait DurableRequestPayloadExt { - fn decode_from_bytes(bytes: &[u8]) -> capnp::Result - where - Self: Sized; - fn encode_to_bytes(&self) -> capnp::Result>; -} - -impl DurableRequestPayloadExt for T -where - T: DurableRequestPayload, -{ - fn encode_to_bytes(&self) -> capnp::Result> { - let mut buf = Vec::new(); - let message = self.encode_to_builder(); - capnp::serialize_packed::write_message(&mut buf, &message)?; - Ok(buf) - } - - fn decode_from_bytes(bytes: &[u8]) -> capnp::Result - where - Self: Sized, - { - let mut cursor = std::io::Cursor::new(bytes); - let reader = capnp::serialize_packed::read_message( - &mut cursor, - capnp::message::ReaderOptions::new(), - )?; - - T::decode_from_reader(reader) - } -} - macro_rules! define_do_binding { ( const BINDING = $binding:literal; diff --git a/crates/daphne-service-utils/src/lib.rs b/crates/daphne-service-utils/src/lib.rs index 9172051cc..92e038461 100644 --- a/crates/daphne-service-utils/src/lib.rs +++ b/crates/daphne-service-utils/src/lib.rs @@ -5,6 +5,8 @@ pub mod bearer_token; #[cfg(feature = "durable_requests")] +pub mod capnproto_payload; +#[cfg(feature = "durable_requests")] pub mod durable_requests; pub mod http_headers; #[cfg(feature = "test-utils")] diff --git a/crates/daphne-worker/src/durable/mod.rs b/crates/daphne-worker/src/durable/mod.rs index 9c476cd2c..b9a61c0b9 100644 --- a/crates/daphne-worker/src/durable/mod.rs +++ b/crates/daphne-worker/src/durable/mod.rs @@ -24,8 +24,9 @@ pub(crate) mod aggregate_store; pub(crate) mod test_state_cleaner; use crate::tracing_utils::shorten_paths; -use daphne_service_utils::durable_requests::bindings::{ - DurableMethod, DurableRequestPayload, DurableRequestPayloadExt, +use daphne_service_utils::{ + capnproto_payload::{CapnprotoPayloadDecode, CapnprotoPayloadDecodeExt}, + durable_requests::bindings::DurableMethod, }; use serde::{Deserialize, Serialize}; use tracing::info_span; @@ -209,7 +210,7 @@ pub(crate) async fn state_set_if_not_exists Deserialize<'a> + Seriali async fn req_parse(req: &mut Request) -> Result where - T: DurableRequestPayload, + T: CapnprotoPayloadDecode, { T::decode_from_bytes(&req.bytes().await?).map_err(|e| Error::RustError(e.to_string())) } diff --git a/crates/daphne/src/constants.rs b/crates/daphne/src/constants.rs index dc709fe85..c6d77218a 100644 --- a/crates/daphne/src/constants.rs +++ b/crates/daphne/src/constants.rs @@ -75,16 +75,6 @@ pub enum DapRole { Helper, } -impl DapRole { - pub fn is_leader(self) -> bool { - self == Self::Leader - } - - pub fn is_helper(self) -> bool { - self == Self::Helper - } -} - impl FromStr for DapRole { type Err = String; @@ -119,16 +109,6 @@ pub enum DapAggregatorRole { Helper, } -impl DapAggregatorRole { - pub fn is_leader(self) -> bool { - self == Self::Leader - } - - pub fn is_helper(self) -> bool { - self == Self::Helper - } -} - impl FromStr for DapAggregatorRole { type Err = String; diff --git a/crates/daphne/src/lib.rs b/crates/daphne/src/lib.rs index 875020cfc..d7499078a 100644 --- a/crates/daphne/src/lib.rs +++ b/crates/daphne/src/lib.rs @@ -93,7 +93,9 @@ use url::Url; use vdaf::mastic::MasticWeight; pub use messages::request::{DapRequest, DapRequestMeta, DapResponse}; -pub use protocol::report_init::{InitializedReport, WithPeerPrepShare}; +pub use protocol::report_init::{ + InitializedReport, PartialDapTaskConfigForReportInit, WithPeerPrepShare, +}; /// DAP version used for a task. #[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize)] diff --git a/crates/daphne/src/pine/vdaf.rs b/crates/daphne/src/pine/vdaf.rs index 025773eec..74bc091ef 100644 --- a/crates/daphne/src/pine/vdaf.rs +++ b/crates/daphne/src/pine/vdaf.rs @@ -383,21 +383,22 @@ where X: Xof, { fn decode_with_param( - (pine, is_leader): &(&Pine, bool), + (pine, role): &(&Pine, bool), bytes: &mut std::io::Cursor<&[u8]>, ) -> Result { - let (gradient_share, meas_share_seed) = if *is_leader { - ( + let (gradient_share, meas_share_seed) = match role { + true => ( std::iter::repeat_with(|| F::decode(bytes)) .take(pine.flp.dimension) .collect::, _>>()?, None, - ) - } else { - let seed = Seed::decode(bytes)?; - let mut gradient_share = pine.helper_meas_share(seed.as_ref()); - gradient_share.truncate(pine.flp.dimension); - (gradient_share, Some(seed)) + ), + false => { + let seed = Seed::decode(bytes)?; + let mut gradient_share = pine.helper_meas_share(seed.as_ref()); + gradient_share.truncate(pine.flp.dimension); + (gradient_share, Some(seed)) + } }; Ok(Self { diff --git a/crates/daphne/src/protocol/report_init.rs b/crates/daphne/src/protocol/report_init.rs index fd4a9a812..c38b5fdd9 100644 --- a/crates/daphne/src/protocol/report_init.rs +++ b/crates/daphne/src/protocol/report_init.rs @@ -8,11 +8,14 @@ use crate::{ self, Extension, PlaintextInputShare, ReportError, ReportMetadata, ReportShare, TaskId, }, protocol::{decode_ping_pong_framed, no_duplicates, PingPongMessageType}, - vdaf::{VdafPrepShare, VdafPrepState}, - DapAggregationParam, DapError, DapTaskConfig, + vdaf::{VdafConfig, VdafPrepShare, VdafPrepState, VdafVerifyKey}, + DapAggregationParam, DapError, DapTaskConfig, DapVersion, }; use prio::codec::{CodecError, ParameterizedDecode as _}; -use std::ops::{Deref, Range}; +use std::{ + borrow::Cow, + ops::{Deref, Range}, +}; /// Report state during aggregation initialization after the VDAF preparation step. /// @@ -38,6 +41,12 @@ pub enum InitializedReport { pub struct WithPeerPrepShare(Vec); +impl From> for WithPeerPrepShare { + fn from(value: Vec) -> Self { + Self(value) + } +} + impl Deref for WithPeerPrepShare { type Target = Vec; fn deref(&self) -> &Self::Target { @@ -46,11 +55,11 @@ impl Deref for WithPeerPrepShare { } impl InitializedReport<()> { - pub fn from_client( + pub fn from_client<'s>( decrypter: &impl HpkeDecrypter, valid_report_range: Range, task_id: &TaskId, - task_config: &DapTaskConfig, + task_config: impl Into>, report_share: ReportShare, agg_param: &DapAggregationParam, ) -> Result { @@ -67,11 +76,11 @@ impl InitializedReport<()> { } impl InitializedReport { - pub fn from_leader( + pub fn from_leader<'s>( decrypter: &impl HpkeDecrypter, valid_report_range: Range, task_id: &TaskId, - task_config: &DapTaskConfig, + task_config: impl Into>, report_share: ReportShare, prep_init_payload: Vec, agg_param: &DapAggregationParam, @@ -88,12 +97,44 @@ impl InitializedReport { } } +impl<'s> From<&'s DapTaskConfig> for PartialDapTaskConfigForReportInit<'s> { + fn from(config: &'s DapTaskConfig) -> Self { + PartialDapTaskConfigForReportInit { + not_after: config.not_after, + method_is_taskprov: config.method_is_taskprov(), + version: config.version, + vdaf: Cow::Borrowed(&config.vdaf), + vdaf_verify_key: Cow::Borrowed(&config.vdaf_verify_key), + } + } +} + +impl<'s> From<&'s PartialDapTaskConfigForReportInit<'_>> for PartialDapTaskConfigForReportInit<'s> { + fn from(config: &'s PartialDapTaskConfigForReportInit<'_>) -> Self { + Self { + not_after: config.not_after, + method_is_taskprov: config.method_is_taskprov, + version: config.version, + vdaf: Cow::Borrowed(&config.vdaf), + vdaf_verify_key: Cow::Borrowed(&config.vdaf_verify_key), + } + } +} + +pub struct PartialDapTaskConfigForReportInit<'s> { + pub not_after: messages::Time, + pub method_is_taskprov: bool, + pub version: DapVersion, + pub vdaf: Cow<'s, VdafConfig>, + pub vdaf_verify_key: Cow<'s, VdafVerifyKey>, +} + impl

InitializedReport

{ - fn initialize( + fn initialize<'s, S>( decrypter: &impl HpkeDecrypter, valid_report_range: Range, task_id: &TaskId, - task_config: &DapTaskConfig, + task_config: impl Into>, report_share: ReportShare, prep_init_payload: S, agg_param: &DapAggregationParam, @@ -101,6 +142,7 @@ impl

InitializedReport

{ where S: PrepInitPayload, { + let task_config = task_config.into(); macro_rules! reject { ($failure:ident) => { return Ok(InitializedReport::Rejected { @@ -158,7 +200,7 @@ impl

InitializedReport

{ let mut taskprov_indicated = false; for extension in extensions { match extension { - Extension::Taskprov { .. } if task_config.method_is_taskprov() => { + Extension::Taskprov { .. } if task_config.method_is_taskprov => { taskprov_indicated = true; } @@ -167,7 +209,7 @@ impl

InitializedReport

{ } } - if task_config.method_is_taskprov() && !taskprov_indicated { + if task_config.method_is_taskprov && !taskprov_indicated { // taskprov: If the task configuration method is taskprov, then we expect each // report to indicate support. reject!(InvalidMessage); diff --git a/crates/daphne/src/vdaf/mod.rs b/crates/daphne/src/vdaf/mod.rs index 60c36898e..2046f8528 100644 --- a/crates/daphne/src/vdaf/mod.rs +++ b/crates/daphne/src/vdaf/mod.rs @@ -28,10 +28,7 @@ use prio::{ #[cfg(any(test, feature = "test-utils", feature = "experimental"))] use prio::field::FieldElement; use prio_draft09::{ - codec::{ - CodecError as CodecErrorDraft09, Encode as EncodeDraft09, - ParameterizedDecode as ParameterizedDecodeDraft09, - }, + codec::{CodecError as CodecErrorDraft09, Encode as EncodeDraft09}, field::{ Field128 as Field128Draft09, Field64 as Field64Draft09, FieldPrio2 as FieldPrio2Draft09, }, @@ -51,9 +48,21 @@ use std::io::Read; #[cfg(feature = "experimental")] pub use self::mastic::MasticWeightConfig; +use crate::constants::DapAggregatorRole; +use prio_draft09::codec::ParameterizedDecode as _; const CTX_STRING_PREFIX: &[u8] = b"dap-13"; +impl DapAggregatorRole { + /// The numeric identifier of the role of the aggregator decoding the vdaf. + fn as_aggregator_id(self) -> usize { + match self { + DapAggregatorRole::Leader => 0, + DapAggregatorRole::Helper => 1, + } + } +} + #[derive(Debug, thiserror::Error)] pub(crate) enum VdafError { #[error("{0}")] @@ -96,17 +105,6 @@ impl std::str::FromStr for VdafConfig { } } -pub(crate) fn from_codec_error(c: CodecErrorDraft09) -> CodecError { - match c { - CodecErrorDraft09::Io(x) => CodecError::Io(x), - CodecErrorDraft09::BytesLeftOver(u) => CodecError::BytesLeftOver(u), - CodecErrorDraft09::LengthPrefixTooBig(u) => CodecError::LengthPrefixTooBig(u), - CodecErrorDraft09::LengthPrefixOverflow => CodecError::LengthPrefixOverflow, - CodecErrorDraft09::Other(x) => CodecError::Other(x), - _ => CodecError::UnexpectedValue, - } -} - impl std::fmt::Display for VdafConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -425,6 +423,17 @@ impl AsMut<[u8]> for VdafVerifyKey { } } +fn upgrade_codec_error(error: CodecErrorDraft09) -> CodecError { + match error { + CodecErrorDraft09::Io(error) => CodecError::Io(error), + CodecErrorDraft09::BytesLeftOver(n) => CodecError::BytesLeftOver(n), + CodecErrorDraft09::LengthPrefixTooBig(n) => CodecError::LengthPrefixTooBig(n), + CodecErrorDraft09::LengthPrefixOverflow => CodecError::LengthPrefixOverflow, + CodecErrorDraft09::Other(error) => CodecError::Other(error), + _ => CodecError::UnexpectedValue, + } +} + /// VDAF preparation state. #[derive(Clone)] #[cfg_attr(any(test, feature = "test-utils"), derive(Debug, Eq, PartialEq))] @@ -432,7 +441,6 @@ pub enum VdafPrepState { Prio2(Prio2PrepareState), Prio3Draft09Field64HmacSha256Aes128(Prio3Draft09PrepareState), Prio3Field64(Prio3PrepareState), - Prio3Field64HmacSha256Aes128(Prio3PrepareState), Prio3Field128(Prio3PrepareState), #[cfg(feature = "experimental")] Mastic { @@ -442,6 +450,43 @@ pub enum VdafPrepState { Pine32HmacSha256Aes128(PinePrepState), } +impl Encode for VdafPrepState { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + match self { + Self::Prio3Draft09Field64HmacSha256Aes128(state) => { + state.encode(bytes).map_err(upgrade_codec_error) + } + + Self::Prio3Field64(state) => state.encode(bytes), + Self::Prio3Field128(state) => state.encode(bytes), + + Self::Prio2(state) => state.encode(bytes), + Self::Pine64HmacSha256Aes128(state) => state.encode(bytes).map_err(upgrade_codec_error), + Self::Pine32HmacSha256Aes128(state) => state.encode(bytes).map_err(upgrade_codec_error), + #[cfg(feature = "experimental")] + Self::Mastic { .. } => todo!("encoding of prep state is not implemented"), + } + } +} + +impl ParameterizedDecode<(&VdafConfig, DapAggregatorRole)> for VdafPrepState { + fn decode_with_param( + (vdaf_config, role): &(&VdafConfig, DapAggregatorRole), + bytes: &mut std::io::Cursor<&[u8]>, + ) -> Result { + match vdaf_config { + VdafConfig::Prio3(prio3_config) => prio3::decode_prep_state(prio3_config, *role, bytes), + VdafConfig::Prio2 { dimension } => prio2::decode_prep_state(*dimension, *role, bytes), + VdafConfig::Pine(config) => pine::decode_prep_state(config, *role, bytes), + #[cfg(feature = "experimental")] + VdafConfig::Mastic { .. } => { + todo!("decoding of mastic prep state is not implemented") + } + } + .map_err(|e| CodecError::Other(Box::new(e))) + } +} + #[cfg(any(test, feature = "test-utils"))] impl deepsize::DeepSizeOf for VdafPrepState { fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { @@ -453,7 +498,6 @@ impl deepsize::DeepSizeOf for VdafPrepState { Self::Prio2(_) | Self::Prio3Draft09Field64HmacSha256Aes128(_) | Self::Prio3Field64(_) - | Self::Prio3Field64HmacSha256Aes128(_) | Self::Prio3Field128(_) | Self::Pine64HmacSha256Aes128(_) | Self::Pine32HmacSha256Aes128(_) => 0, @@ -470,7 +514,6 @@ pub enum VdafPrepShare { Prio2(Prio2PrepareShare), Prio3Draft09Field64HmacSha256Aes128(Prio3Draft09PrepareShare), Prio3Field64(Prio3PrepareShare), - Prio3Field64HmacSha256Aes128(Prio3PrepareShare), Prio3Field128(Prio3PrepareShare), #[cfg(feature = "experimental")] Mastic(Field64), @@ -491,7 +534,6 @@ impl deepsize::DeepSizeOf for VdafPrepShare { // type. Self::Prio3Draft09Field64HmacSha256Aes128(..) | Self::Prio3Field64(..) - | Self::Prio3Field64HmacSha256Aes128(..) | Self::Prio3Field128(..) | Self::Pine64HmacSha256Aes128(_) | Self::Pine32HmacSha256Aes128(_) => 0, @@ -505,16 +547,15 @@ impl Encode for VdafPrepShare { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { match self { Self::Prio3Draft09Field64HmacSha256Aes128(share) => { - share.encode(bytes).map_err(from_codec_error) + share.encode(bytes).map_err(upgrade_codec_error) } Self::Prio3Field64(share) => share.encode(bytes), - Self::Prio3Field64HmacSha256Aes128(share) => share.encode(bytes), Self::Prio3Field128(share) => share.encode(bytes), Self::Prio2(share) => share.encode(bytes), #[cfg(feature = "experimental")] Self::Mastic(share) => share.encode(bytes), - Self::Pine64HmacSha256Aes128(share) => share.encode(bytes).map_err(from_codec_error), - Self::Pine32HmacSha256Aes128(share) => share.encode(bytes).map_err(from_codec_error), + Self::Pine64HmacSha256Aes128(share) => share.encode(bytes).map_err(upgrade_codec_error), + Self::Pine32HmacSha256Aes128(share) => share.encode(bytes).map_err(upgrade_codec_error), } } } @@ -528,17 +569,12 @@ impl ParameterizedDecode for VdafPrepShare { VdafPrepState::Prio3Draft09Field64HmacSha256Aes128(state) => { Ok(VdafPrepShare::Prio3Draft09Field64HmacSha256Aes128( Prio3Draft09PrepareShare::decode_with_param(state, bytes) - .map_err(from_codec_error)?, + .map_err(upgrade_codec_error)?, )) } VdafPrepState::Prio3Field64(state) => Ok(VdafPrepShare::Prio3Field64( Prio3PrepareShare::decode_with_param(state, bytes)?, )), - VdafPrepState::Prio3Field64HmacSha256Aes128(state) => { - Ok(VdafPrepShare::Prio3Field64HmacSha256Aes128( - Prio3PrepareShare::decode_with_param(state, bytes)?, - )) - } VdafPrepState::Prio3Field128(state) => Ok(VdafPrepShare::Prio3Field128( Prio3PrepareShare::decode_with_param(state, bytes)?, )), @@ -552,13 +588,13 @@ impl ParameterizedDecode for VdafPrepShare { VdafPrepState::Pine64HmacSha256Aes128(state) => { Ok(VdafPrepShare::Pine64HmacSha256Aes128( crate::pine::msg::PrepShare::decode_with_param(state, bytes) - .map_err(from_codec_error)?, + .map_err(upgrade_codec_error)?, )) } VdafPrepState::Pine32HmacSha256Aes128(state) => { Ok(VdafPrepShare::Pine32HmacSha256Aes128( crate::pine::msg::PrepShare::decode_with_param(state, bytes) - .map_err(from_codec_error)?, + .map_err(upgrade_codec_error)?, )) } } @@ -594,13 +630,13 @@ impl Encode for VdafAggregateShare { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { match self { VdafAggregateShare::Field32Draft09(agg_share) => { - agg_share.encode(bytes).map_err(from_codec_error) + agg_share.encode(bytes).map_err(upgrade_codec_error) } VdafAggregateShare::Field64Draft09(agg_share) => { - agg_share.encode(bytes).map_err(from_codec_error) + agg_share.encode(bytes).map_err(upgrade_codec_error) } VdafAggregateShare::Field128Draft09(agg_share) => { - agg_share.encode(bytes).map_err(from_codec_error) + agg_share.encode(bytes).map_err(upgrade_codec_error) } VdafAggregateShare::Field32(agg_share) => agg_share.encode(bytes), VdafAggregateShare::Field64(agg_share) => agg_share.encode(bytes), @@ -625,7 +661,7 @@ impl VdafConfig { Self::Prio2 { .. } | Self::Prio3(..) => Ok(VdafVerifyKey::L32( <[u8; 32]>::try_from(bytes) .map_err(|e| CodecErrorDraft09::Other(Box::new(e))) - .map_err(from_codec_error)?, + .map_err(upgrade_codec_error)?, )), #[cfg(feature = "experimental")] Self::Mastic { .. } => Ok(VdafVerifyKey::L16( @@ -634,7 +670,7 @@ impl VdafConfig { Self::Pine(..) => Ok(VdafVerifyKey::L32( <[u8; 32]>::try_from(bytes) .map_err(|e| CodecErrorDraft09::Other(Box::new(e))) - .map_err(from_codec_error)?, + .map_err(upgrade_codec_error)?, )), } } diff --git a/crates/daphne/src/vdaf/pine.rs b/crates/daphne/src/vdaf/pine.rs index 2fc0352ba..e1e0d8dc2 100644 --- a/crates/daphne/src/vdaf/pine.rs +++ b/crates/daphne/src/vdaf/pine.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause use crate::{ + constants::DapAggregatorRole, fatal_error, messages::taskprov::{ VDAF_TYPE_PINE_FIELD32_HMAC_SHA256_AES128, VDAF_TYPE_PINE_FIELD64_HMAC_SHA256_AES128, @@ -20,6 +21,7 @@ use prio_draft09::{ }, }; use serde::{Deserialize, Serialize}; +use std::io::Cursor; pub(crate) fn pine32_hmac_sha256_aes128( param: &PineParam, @@ -234,6 +236,33 @@ fn prep_init, const SEED_SIZE: usi Ok(vdaf.prepare_init(verify_key, agg_id, &(), nonce, &public_share, &input_share)?) } +pub fn decode_prep_state( + config: &PineConfig, + role: DapAggregatorRole, + cursor: &mut Cursor<&[u8]>, +) -> Result { + Ok(match config { + PineConfig::Field32HmacSha256Aes128 { param } => VdafPrepState::Pine32HmacSha256Aes128( + PinePrepState::::decode_with_param( + &( + &pine32_hmac_sha256_aes128(param)?, + role == DapAggregatorRole::Leader, + ), + cursor, + )?, + ), + PineConfig::Field64HmacSha256Aes128 { param } => { + VdafPrepState::Pine64HmacSha256Aes128(PinePrepState::::decode_with_param( + &( + &pine64_hmac_sha256_aes128(param)?, + role == DapAggregatorRole::Leader, + ), + cursor, + )?) + } + }) +} + #[cfg(test)] mod test { use crate::{ diff --git a/crates/daphne/src/vdaf/prio2.rs b/crates/daphne/src/vdaf/prio2.rs index c0cca4e22..d8086cee9 100644 --- a/crates/daphne/src/vdaf/prio2.rs +++ b/crates/daphne/src/vdaf/prio2.rs @@ -5,17 +5,18 @@ //! [VDAF](https://datatracker.ietf.org/doc/draft-patton-cfrg-vdaf/). use crate::{ - fatal_error, vdaf::VdafError, DapAggregateResult, DapMeasurement, VdafAggregateShare, - VdafPrepShare, VdafPrepState, VdafVerifyKey, + constants::DapAggregatorRole, fatal_error, vdaf::VdafError, DapAggregateResult, DapMeasurement, + VdafAggregateShare, VdafPrepShare, VdafPrepState, VdafVerifyKey, }; use prio::{ codec::{Decode, Encode, ParameterizedDecode}, field::FieldPrio2, vdaf::{ - prio2::{Prio2, Prio2PrepareShare}, + prio2::{Prio2, Prio2PrepareShare, Prio2PrepareState}, AggregateShare, Aggregator, Client, Collector, PrepareTransition, Share, Vdaf, }, }; +use std::io::Cursor; /// Split the given measurement into a sequence of encoded input shares. pub(crate) fn prio2_shard( @@ -143,6 +144,21 @@ pub(crate) fn prio2_prep_finish( Ok(agg_share) } +/// Parse our prep state. +pub(crate) fn decode_prep_state( + dimension: usize, + role: DapAggregatorRole, + bytes: &mut Cursor<&[u8]>, +) -> Result { + let vdaf = Prio2::new(dimension).map_err(|e| { + VdafError::Dap(fatal_error!(err = ?e, "failed to create prio2 from {dimension}")) + })?; + Ok(VdafPrepState::Prio2(Prio2PrepareState::decode_with_param( + &(&vdaf, role.as_aggregator_id()), + bytes, + )?)) +} + /// Interpret `encoded_agg_shares` as a sequence of encoded aggregate shares and unshard them. pub(crate) fn prio2_unshard>>( dimension: usize, diff --git a/crates/daphne/src/vdaf/prio3.rs b/crates/daphne/src/vdaf/prio3.rs index ea88d95da..bce7f3a4a 100644 --- a/crates/daphne/src/vdaf/prio3.rs +++ b/crates/daphne/src/vdaf/prio3.rs @@ -4,6 +4,7 @@ //! Parameters for the [Prio3 VDAF](https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/13/). use crate::{ + constants::DapAggregatorRole, fatal_error, messages::TaskId, vdaf::{draft09, VdafError, VdafVerifyKey}, @@ -11,6 +12,8 @@ use crate::{ VdafPrepState, }; +use super::{prep_finish, prep_finish_from_shares, shard_then_encode, unshard}; + use prio::{ codec::ParameterizedDecode, flp::Type, @@ -20,8 +23,7 @@ use prio::{ Aggregator, }, }; - -use super::{prep_finish, prep_finish_from_shares, shard_then_encode, unshard}; +use std::io::Cursor; const CTX_STRING_PREFIX: &[u8] = b"dap-13"; @@ -560,6 +562,74 @@ impl Prio3Config { } } +/// Parse our prep state. +pub(crate) fn decode_prep_state( + config: &Prio3Config, + role: DapAggregatorRole, + bytes: &mut Cursor<&[u8]>, +) -> Result { + let agg_id = role.as_aggregator_id(); + match config { + Prio3Config::Count => { + let vdaf = Prio3::new_count(2).map_err(|e| { + VdafError::Dap( + fatal_error!(err = ?e, "failed to create prio3 count from num_aggregators(2)"), + ) + })?; + Ok(VdafPrepState::Prio3Field64( + Prio3PrepareState::decode_with_param(&(&vdaf, agg_id), bytes)?, + )) + } + Prio3Config::Histogram { + length, + chunk_length, + } => { + let vdaf = Prio3::new_histogram(2, *length, *chunk_length) + .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 histogram from num_aggregators(2), length({length}), chunk_length({chunk_length})")))?; + Ok(VdafPrepState::Prio3Field128( + Prio3PrepareState::decode_with_param(&(&vdaf, agg_id), bytes)?, + )) + } + Prio3Config::Sum { max_measurement } => { + let vdaf = + Prio3::new_sum(2, *max_measurement) + .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum from num_aggregators(2), max_measurement({max_measurement})")))?; + Ok(VdafPrepState::Prio3Field64( + Prio3PrepareState::decode_with_param(&(&vdaf, agg_id), bytes) + .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum from num_aggregators(2), max_measurement({max_measurement})")))?, + )) + } + Prio3Config::SumVec { + bits, + length, + chunk_length, + } => { + let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length) + .map_err(|e| VdafError::Dap(fatal_error!(err = ?e, "failed to create prio3 sum vec from num_aggregators(2), bits({bits}), length({length}), chunk_length({chunk_length})")))?; + Ok(VdafPrepState::Prio3Field128( + Prio3PrepareState::decode_with_param(&(&vdaf, agg_id), bytes)?, + )) + } + Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 { + bits, + length, + chunk_length, + num_proofs, + } => { + use prio_draft09::{codec::ParameterizedDecode, vdaf::prio3::Prio3PrepareState}; + let vdaf = super::draft09::new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( + *bits, + *length, + *chunk_length, + *num_proofs, + )?; + Ok(VdafPrepState::Prio3Draft09Field64HmacSha256Aes128( + Prio3PrepareState::decode_with_param(&(&vdaf, agg_id), bytes)?, + )) + } + } +} + #[cfg(test)] mod test {