diff --git a/.github/workflows/daphneci.yml b/.github/workflows/daphneci.yml index 68de16dfa..cf23ddd34 100644 --- a/.github/workflows/daphneci.yml +++ b/.github/workflows/daphneci.yml @@ -19,7 +19,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: stable + toolchain: 1.83.0 components: clippy, rustfmt override: true - name: Machete diff --git a/crates/daphne-server/src/roles/helper.rs b/crates/daphne-server/src/roles/helper.rs index f312c7441..5257d04ce 100644 --- a/crates/daphne-server/src/roles/helper.rs +++ b/crates/daphne-server/src/roles/helper.rs @@ -1,8 +1,23 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause use axum::async_trait; -use daphne::roles::DapHelper; +use daphne::{ + messages::{request::AggregationJobRequestHash, AggregationJobId, TaskId}, + roles::DapHelper, + DapError, DapVersion, +}; #[async_trait] -impl DapHelper for crate::App {} +impl DapHelper for crate::App { + async fn assert_agg_job_is_immutable( + &self, + _id: AggregationJobId, + _version: DapVersion, + _task_id: &TaskId, + _req: &AggregationJobRequestHash, + ) -> Result<(), DapError> { + // the server implementation can't check for this + Ok(()) + } +} diff --git a/crates/daphne-server/src/router/extractor.rs b/crates/daphne-server/src/router/extractor.rs index e9230d1b2..eccd4920d 100644 --- a/crates/daphne-server/src/router/extractor.rs +++ b/crates/daphne-server/src/router/extractor.rs @@ -13,7 +13,7 @@ use daphne::{ error::DapAbort, fatal_error, messages::{ - request::{CollectionPollReq, RequestBody}, + request::{CollectionPollReq, HashedAggregationJobReq, RequestBody}, taskprov::TaskprovAdvertisement, AggregateShareReq, AggregationJobInitReq, CollectionReq, Report, TaskId, }, @@ -60,6 +60,17 @@ impl_decode_from_dap_http_body!( CollectionReq, ); +impl DecodeFromDapHttpBody for HashedAggregationJobReq { + fn decode_from_http_body(bytes: 
Bytes, meta: &DapRequestMeta) -> Result { + let mut cursor = Cursor::new(bytes.as_ref()); + // Check that media type matches. + meta.get_checked_media_type(DapMediaType::AggregationJobInitReq)?; + // Decode the body + HashedAggregationJobReq::decode_with_param(&meta.version, &mut cursor) + .map_err(|e| DapAbort::from_codec_error(e, meta.task_id)) + } +} + /// Using `()` ignores the body of a request. impl DecodeFromDapHttpBody for CollectionPollReq { fn decode_from_http_body(_bytes: Bytes, _meta: &DapRequestMeta) -> Result { diff --git a/crates/daphne-server/src/router/helper.rs b/crates/daphne-server/src/router/helper.rs index d6f8eced3..bb7d9c4eb 100644 --- a/crates/daphne-server/src/router/helper.rs +++ b/crates/daphne-server/src/router/helper.rs @@ -8,7 +8,7 @@ use axum::{ routing::{post, put}, }; use daphne::{ - messages::{AggregateShareReq, AggregationJobInitReq}, + messages::{request::HashedAggregationJobReq, AggregateShareReq}, roles::{helper, DapHelper}, }; use http::StatusCode; @@ -38,7 +38,7 @@ pub(super) fn add_helper_routes(router: super::Router) -> super::Router>, - DapRequestExtractor(req): DapRequestExtractor, + DapRequestExtractor(req): DapRequestExtractor, ) -> AxumDapResponse { let timer = std::time::Instant::now(); diff --git a/crates/daphne-service-utils/Cargo.toml b/crates/daphne-service-utils/Cargo.toml index 2c9d50b00..f906bf3c7 100644 --- a/crates/daphne-service-utils/Cargo.toml +++ b/crates/daphne-service-utils/Cargo.toml @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +# Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause [package] @@ -28,7 +28,12 @@ capnpc = { workspace = true, optional = true } [features] test-utils = ["dep:url", "daphne/prometheus", "daphne/test-utils"] -durable_requests = ["dep:capnp", "dep:capnpc", "dep:prio_draft09", "dep:prio"] +durable_requests = [ + "dep:capnp", + "dep:capnpc", + "dep:prio_draft09", + "dep:prio", +] experimental = ["daphne/experimental"] [lints] diff --git a/crates/daphne-service-utils/build.rs b/crates/daphne-service-utils/build.rs index 6b9d9f13a..2e78bf6aa 100644 --- a/crates/daphne-service-utils/build.rs +++ b/crates/daphne-service-utils/build.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause fn main() { @@ -6,6 +6,7 @@ fn main() { ::capnpc::CompilerCommand::new() .file("./src/capnproto/base.capnp") .file("./src/durable_requests/durable_request.capnp") + .file("./src/durable_requests/bindings/aggregation_job_store.capnp") .run() .expect("compiling schema"); } diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.capnp b/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.capnp new file mode 100644 index 000000000..cbbb64839 --- /dev/null +++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.capnp @@ -0,0 +1,11 @@ +# Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +@0xa11edd1197dbcf0b; + +using Base = import "../../capnproto/base.capnp"; + +struct NewJobRequest { + id @0 :Base.AggregationJobId; + aggJobHash @1 :Data; +} diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs b/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs new file mode 100644 index 000000000..9271f707d --- /dev/null +++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs @@ -0,0 +1,63 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use crate::{ + aggregation_job_store_capnp::new_job_request, + capnproto::{CapnprotoPayloadDecode, CapnprotoPayloadEncode}, + durable_requests::ObjectIdFrom, +}; +use daphne::{ + messages::{AggregationJobId, TaskId}, + DapVersion, +}; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; + +super::define_do_binding! { + const BINDING = "AGGREGATION_JOB_STORE"; + + enum Command { + NewJob = "/new-job", + ListJobIds = "/job-ids", + } + + fn name((version, task_id): (DapVersion, &'n TaskId)) -> ObjectIdFrom { + ObjectIdFrom::Name(format!("{version}/task/{task_id}")) + } +} + +#[derive(Debug)] +pub struct NewJobRequest<'h> { + pub id: AggregationJobId, + pub agg_job_hash: Cow<'h, [u8]>, +} + +impl CapnprotoPayloadEncode for NewJobRequest<'_> { + type Builder<'a> = new_job_request::Builder<'a>; + + fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { + self.id.encode_to_builder(builder.reborrow().init_id()); + builder.set_agg_job_hash(&self.agg_job_hash); + } +} + +impl CapnprotoPayloadDecode for NewJobRequest<'static> { + type Reader<'a> = new_job_request::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result + where + Self: Sized, + { + Ok(Self { + id: <_>::decode_from_reader(reader.get_id()?)?, + agg_job_hash: reader.get_agg_job_hash()?.to_vec().into(), + }) + } +} + 
+#[derive(Debug, Serialize, Deserialize)] +pub enum NewJobResponse { + Ok, + /// Request would change an existing aggregation job's parameters. + IllegalJobParameters, +} diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs b/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs index 911327237..8030e9b52 100644 --- a/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs +++ b/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause //! This module defines the durable objects' binding and methods as implementors of the @@ -7,6 +7,7 @@ //! It also defines types that are used as the body of requests sent to these objects. mod aggregate_store; +pub mod aggregation_job_store; #[cfg(feature = "test-utils")] mod test_state_cleaner; diff --git a/crates/daphne-service-utils/src/lib.rs b/crates/daphne-service-utils/src/lib.rs index fbd06364b..c31aabad0 100644 --- a/crates/daphne-service-utils/src/lib.rs +++ b/crates/daphne-service-utils/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
// SPDX-License-Identifier: BSD-3-Clause #![cfg_attr(not(test), deny(unused_crate_dependencies))] @@ -32,3 +32,14 @@ mod durable_request_capnp { "/src/durable_requests/durable_request_capnp.rs" )); } + +#[cfg(feature = "durable_requests")] +mod aggregation_job_store_capnp { + #![allow(dead_code)] + #![allow(clippy::pedantic)] + #![allow(clippy::needless_lifetimes)] + include!(concat!( + env!("OUT_DIR"), + "/src/durable_requests/bindings/aggregation_job_store_capnp.rs" + )); +} diff --git a/crates/daphne-worker-test/src/durable.rs b/crates/daphne-worker-test/src/durable.rs index 7c8d5bc89..6fb72a456 100644 --- a/crates/daphne-worker-test/src/durable.rs +++ b/crates/daphne-worker-test/src/durable.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause use daphne_worker::durable::{self, instantiate_durable_object}; @@ -10,3 +10,11 @@ instantiate_durable_object! { daphne_worker::tracing_utils::initialize_tracing(env); } } + +instantiate_durable_object! 
{ + struct AggregationJobStore < durable::AggregationJobStore; + + fn init_user_data(_state: State, env: Env) { + daphne_worker::tracing_utils::initialize_tracing(env); + } +} diff --git a/crates/daphne-worker-test/wrangler.aggregator.toml b/crates/daphne-worker-test/wrangler.aggregator.toml index 7173056ef..98e0b37ce 100644 --- a/crates/daphne-worker-test/wrangler.aggregator.toml +++ b/crates/daphne-worker-test/wrangler.aggregator.toml @@ -67,6 +67,7 @@ ip = "0.0.0.0" bindings = [ { name = "DAP_AGGREGATE_STORE", class_name = "AggregateStore" }, { name = "DAP_TEST_STATE_CLEANER", class_name = "TestStateCleaner" }, + { name = "AGGREGATION_JOB_STORE", class_name = "AggregationJobStore" }, ] @@ -128,6 +129,7 @@ public_key = "047dab625e0d269abcc28c611bebf5a60987ddf7e23df0e0aa343e5774ad81a1d0 bindings = [ { name = "DAP_AGGREGATE_STORE", class_name = "AggregateStore" }, { name = "DAP_TEST_STATE_CLEANER", class_name = "TestStateCleaner" }, + { name = "AGGREGATION_JOB_STORE", class_name = "AggregationJobStore" }, ] [[env.leader.kv_namespaces]] @@ -154,4 +156,5 @@ tag = "v1" new_classes = [ "AggregateStore", "GarbageCollector", + "AggregationJobStore", ] diff --git a/crates/daphne-worker/src/aggregator/roles/helper.rs b/crates/daphne-worker/src/aggregator/roles/helper.rs index c6aec453c..42d836d2a 100644 --- a/crates/daphne-worker/src/aggregator/roles/helper.rs +++ b/crates/daphne-worker/src/aggregator/roles/helper.rs @@ -2,6 +2,43 @@ // SPDX-License-Identifier: BSD-3-Clause use crate::aggregator::App; -use daphne::roles::DapHelper; +use daphne::{ + error::DapAbort, + fatal_error, + messages::{request::AggregationJobRequestHash, AggregationJobId, TaskId}, + roles::DapHelper, + DapError, DapVersion, +}; +use daphne_service_utils::durable_requests::bindings::aggregation_job_store; +use std::borrow::Cow; -impl DapHelper for App {} +#[axum::async_trait] +impl DapHelper for App { + async fn assert_agg_job_is_immutable( + &self, + id: AggregationJobId, + version: DapVersion, + 
task_id: &TaskId, + req: &AggregationJobRequestHash, + ) -> Result<(), DapError> { + let response = self + .durable() + .with_retry() + .request(aggregation_job_store::Command::NewJob, (version, task_id)) + .encode(&aggregation_job_store::NewJobRequest { + id, + agg_job_hash: Cow::Borrowed(req.get()), + }) + .send::() + .await + .map_err(|e| fatal_error!(err = ?e, "failed to store aggregation job hash"))?; + + match response { + aggregation_job_store::NewJobResponse::Ok => Ok(()), + aggregation_job_store::NewJobResponse::IllegalJobParameters => Err( + DapAbort::BadRequest("aggregation job replay changes parameters".to_string()) + .into(), + ), + } + } +} diff --git a/crates/daphne-worker/src/aggregator/router/extractor.rs b/crates/daphne-worker/src/aggregator/router/extractor.rs index 5b9762f62..d065dfcd6 100644 --- a/crates/daphne-worker/src/aggregator/router/extractor.rs +++ b/crates/daphne-worker/src/aggregator/router/extractor.rs @@ -1,8 +1,6 @@ // Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause -use std::io::Cursor; - use axum::{ async_trait, body::Bytes, @@ -13,7 +11,7 @@ use daphne::{ error::DapAbort, fatal_error, messages::{ - request::{CollectionPollReq, RequestBody}, + request::{CollectionPollReq, HashedAggregationJobReq, RequestBody}, taskprov::TaskprovAdvertisement, AggregateShareReq, AggregationJobInitReq, CollectionReq, Report, TaskId, }, @@ -41,11 +39,10 @@ macro_rules! impl_decode_from_dap_http_body { bytes: Bytes, meta: &DapRequestMeta, ) -> Result { - let mut cursor = Cursor::new(bytes.as_ref()); // Check that media type matches. 
meta.get_checked_media_type(DapMediaType::$type)?; // Decode the body - $type::decode_with_param(&meta.version, &mut cursor) + $type::get_decoded_with_param(&meta.version, bytes.as_ref()) .map_err(|e| DapAbort::from_codec_error(e, meta.task_id)) } } @@ -60,6 +57,16 @@ impl_decode_from_dap_http_body!( CollectionReq, ); +impl DecodeFromDapHttpBody for HashedAggregationJobReq { + fn decode_from_http_body(bytes: Bytes, meta: &DapRequestMeta) -> Result { + // Check that media type matches. + meta.get_checked_media_type(DapMediaType::AggregationJobInitReq)?; + // Decode the body + HashedAggregationJobReq::get_decoded_with_param(&meta.version, bytes.as_ref()) + .map_err(|e| DapAbort::from_codec_error(e, meta.task_id)) + } +} + /// Using `()` ignores the body of a request. impl DecodeFromDapHttpBody for CollectionPollReq { fn decode_from_http_body(_bytes: Bytes, _meta: &DapRequestMeta) -> Result { diff --git a/crates/daphne-worker/src/aggregator/router/helper.rs b/crates/daphne-worker/src/aggregator/router/helper.rs index 6d4fd7512..98556c2e1 100644 --- a/crates/daphne-worker/src/aggregator/router/helper.rs +++ b/crates/daphne-worker/src/aggregator/router/helper.rs @@ -8,7 +8,7 @@ use axum::{ routing::{post, put}, }; use daphne::{ - messages::{AggregateShareReq, AggregationJobInitReq}, + messages::{request::HashedAggregationJobReq, AggregateShareReq}, roles::{helper, DapHelper}, }; use http::StatusCode; @@ -39,7 +39,7 @@ pub(super) fn add_helper_routes(router: super::Router) -> super::Router>, - DapRequestExtractor(req): DapRequestExtractor, + DapRequestExtractor(req): DapRequestExtractor, ) -> AxumDapResponse { let now = worker::Date::now(); diff --git a/crates/daphne-worker/src/durable/aggregation_job_store.rs b/crates/daphne-worker/src/durable/aggregation_job_store.rs new file mode 100644 index 000000000..d2a0d3e54 --- /dev/null +++ b/crates/daphne-worker/src/durable/aggregation_job_store.rs @@ -0,0 +1,112 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
+// SPDX-License-Identifier: BSD-3-Clause + +use super::{req_parse, GcDurableObject}; +use crate::int_err; +use daphne::messages::AggregationJobId; +use daphne_service_utils::durable_requests::bindings::{ + aggregation_job_store::{self, NewJobResponse}, + DurableMethod, +}; +use std::{collections::HashSet, sync::OnceLock, time::Duration}; +use worker::{js_sys::Uint8Array, Request, Response}; + +super::mk_durable_object! { + struct AggregationJobStore { + state: State, + env: Env, + seen_agg_job_ids: Option>, + } +} + +const SEEN_AGG_JOB_IDS_KEY: &str = "agg-job-ids"; + +impl GcDurableObject for AggregationJobStore { + type DurableMethod = aggregation_job_store::Command; + + fn with_state_and_env(state: worker::State, env: worker::Env) -> Self { + Self { + state, + env, + seen_agg_job_ids: None, + } + } + + async fn handle(&mut self, mut req: Request) -> worker::Result { + match Self::DurableMethod::try_from_uri(&req.path()) { + Some(aggregation_job_store::Command::NewJob) => { + let aggregation_job_store::NewJobRequest { id, agg_job_hash } = + req_parse(&mut req).await?; + + let key = &id.to_string(); + let response = match self.get::>(key).await? { + Some(hash) if hash == *agg_job_hash => NewJobResponse::Ok, + Some(_) => NewJobResponse::IllegalJobParameters, + None => { + self.state + .storage() + .put_raw(key, Uint8Array::from(agg_job_hash.as_ref())) + .await?; + let seen_agg_job_ids = self.load_seen_agg_job_ids().await?; + seen_agg_job_ids.insert(id); + self.store_seen_agg_job_ids().await?; + NewJobResponse::Ok + } + }; + + Response::from_json(&response) + } + Some(aggregation_job_store::Command::ListJobIds) => { + Response::from_json(&self.load_seen_agg_job_ids().await?) 
+ } + None => Err(int_err(format!( + "AggregationJobStore: unexpected request: method={:?}; path={:?}", + req.method(), + req.path() + ))), + } + } + + fn should_cleanup_at(&self) -> Option { + const VAR_NAME: &str = "DO_AGGREGATION_JOB_STORE_GC_AFTER_SECS"; + static SELF_DELETE_AFTER: OnceLock = OnceLock::new(); + + let duration = SELF_DELETE_AFTER.get_or_init(|| { + Duration::from_secs( + self.env + .var(VAR_NAME) + .map(|v| { + v.to_string().parse().unwrap_or_else(|e| { + panic!("{VAR_NAME} could not be parsed as a number of seconds: {e}") + }) + }) + .unwrap_or(60 * 60 * 24 * 7), // one week + ) + }); + + Some(worker::ScheduledTime::from(*duration)) + } +} + +impl AggregationJobStore { + async fn load_seen_agg_job_ids(&mut self) -> worker::Result<&mut HashSet> { + let seen_agg_job_ids = if let Some(seen_agg_job_ids) = self.seen_agg_job_ids.take() { + seen_agg_job_ids + } else { + self.get_or_default(SEEN_AGG_JOB_IDS_KEY).await? + }; + + self.seen_agg_job_ids = Some(seen_agg_job_ids); + + Ok(self.seen_agg_job_ids.as_mut().unwrap()) + } + + async fn store_seen_agg_job_ids(&mut self) -> worker::Result<()> { + self.put( + SEEN_AGG_JOB_IDS_KEY, + self.seen_agg_job_ids.as_ref().unwrap(), + ) + .await?; + Ok(()) + } +} diff --git a/crates/daphne-worker/src/durable/mod.rs b/crates/daphne-worker/src/durable/mod.rs index 1c2fdd2c4..b719e8344 100644 --- a/crates/daphne-worker/src/durable/mod.rs +++ b/crates/daphne-worker/src/durable/mod.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause //! This module defines the durable object implementations needed to run the DAP service. It @@ -20,6 +20,7 @@ //! this module as well as the [`instantiate_durable_object`] macro, respectively. 
pub(crate) mod aggregate_store; +pub(crate) mod aggregation_job_store; #[cfg(feature = "test-utils")] pub(crate) mod test_state_cleaner; @@ -33,6 +34,7 @@ use tracing::info_span; use worker::{Env, Error, Request, Response, Result, ScheduledTime, State}; pub use aggregate_store::AggregateStore; +pub use aggregation_job_store::AggregationJobStore; const ERR_NO_VALUE: &str = "No such value in storage."; @@ -115,7 +117,7 @@ macro_rules! mk_durable_object { } #[doc(hidden)] - pub async fn alarm(&mut self) -> Result { + pub async fn alarm(&mut self) -> ::worker::Result { ::tracing::trace!( instance = self.state.id().to_string(), "{}: alarm triggered, deleting...", @@ -137,6 +139,7 @@ macro_rules! mk_durable_object { &self.env } + #[allow(dead_code)] async fn get(&self, key: &str) -> ::worker::Result> where T: ::serde::de::DeserializeOwned, @@ -144,6 +147,7 @@ macro_rules! mk_durable_object { $crate::durable::state_get(&self.state, key).await } + #[allow(dead_code)] async fn get_or_default(&self, key: &str) -> ::worker::Result where T: ::serde::de::DeserializeOwned + std::default::Default, @@ -151,8 +155,16 @@ macro_rules! 
mk_durable_object { $crate::durable::state_get_or_default(&self.state, key).await } - #[expect(dead_code)] - async fn set_if_not_exists(&self, key: &str, val: &T) -> ::worker::Result> + #[allow(dead_code)] + async fn put(&self, key: &str, val: &T) -> ::worker::Result<()> + where + T: ::serde::Serialize, + { + self.state.storage().put(key, val).await + } + + #[allow(dead_code)] + async fn put_if_not_exists(&self, key: &str, val: &T) -> ::worker::Result> where T: ::serde::de::DeserializeOwned + ::serde::Serialize, { diff --git a/crates/daphne-worker/src/durable/test_state_cleaner.rs b/crates/daphne-worker/src/durable/test_state_cleaner.rs index 93f545470..e2b065116 100644 --- a/crates/daphne-worker/src/durable/test_state_cleaner.rs +++ b/crates/daphne-worker/src/durable/test_state_cleaner.rs @@ -43,14 +43,6 @@ impl TestStateCleaner { Some(bindings::TestStateCleaner::Put) => { let durable_ref: DurableReference = serde_json::from_slice(&req.bytes().await?).unwrap(); - match durable_ref.binding.as_ref() { - bindings::AggregateStore::BINDING => (), - s => { - let message = format!("GarbageCollector: unrecognized binding: {s}"); - console_error!("{}", message); - return Err(int_err(message)); - } - }; let queued = DurableOrdered::new_roughly_ordered(durable_ref, "object"); queued.put(&self.state).await?; diff --git a/crates/daphne-worker/src/storage/mod.rs b/crates/daphne-worker/src/storage/mod.rs index 99bf5289a..87c5951e7 100644 --- a/crates/daphne-worker/src/storage/mod.rs +++ b/crates/daphne-worker/src/storage/mod.rs @@ -35,7 +35,6 @@ impl<'h> Do<'h> { Self { env, retry: false } } - #[expect(dead_code)] pub fn with_retry(self) -> Self { Self { retry: true, diff --git a/crates/daphne/src/messages/mod.rs b/crates/daphne/src/messages/mod.rs index a0e2abbf7..959a96c1e 100644 --- a/crates/daphne/src/messages/mod.rs +++ b/crates/daphne/src/messages/mod.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. 
All rights reserved. // SPDX-License-Identifier: BSD-3-Clause //! Messages in the DAP protocol. @@ -585,6 +585,7 @@ impl ParameterizedDecode for PrepareInit { /// Aggregate initialization request. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] pub struct AggregationJobInitReq { pub agg_param: Vec, pub part_batch_sel: PartialBatchSelector, diff --git a/crates/daphne/src/messages/request.rs b/crates/daphne/src/messages/request.rs index a1135ffaa..aeeca629b 100644 --- a/crates/daphne/src/messages/request.rs +++ b/crates/daphne/src/messages/request.rs @@ -8,6 +8,7 @@ use super::{ CollectionJobId, CollectionReq, Report, }; use crate::{constants::DapMediaType, error::DapAbort, messages::TaskId, DapVersion}; +use prio::codec::{ParameterizedDecode, ParameterizedEncode}; pub trait RequestBody { type ResourceId; @@ -24,15 +25,79 @@ macro_rules! impl_req_body { }; } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] +pub struct AggregationJobRequestHash(Vec); + +impl AggregationJobRequestHash { + pub fn get(&self) -> &[u8] { + &self.0 + } + + fn hash(bytes: &[u8]) -> Self { + Self( + ring::digest::digest(&ring::digest::SHA256, bytes) + .as_ref() + .to_vec(), + ) + } +} + +pub struct HashedAggregationJobReq { + pub request: AggregationJobInitReq, + pub hash: AggregationJobRequestHash, +} + +impl HashedAggregationJobReq { + #[cfg(any(test, feature = "test-utils"))] + pub fn from_aggregation_req(version: DapVersion, request: AggregationJobInitReq) -> Self { + let mut buf = Vec::new(); + request.encode_with_param(&version, &mut buf).unwrap(); + Self { + request, + hash: AggregationJobRequestHash::hash(&buf), + } + } +} + +impl ParameterizedEncode for HashedAggregationJobReq { + fn encode_with_param( + &self, + encoding_parameter: &DapVersion, + bytes: &mut Vec, + ) -> Result<(), prio::codec::CodecError> { + 
self.request.encode_with_param(encoding_parameter, bytes) + } +} + +impl ParameterizedDecode for HashedAggregationJobReq { + fn decode_with_param( + decoding_parameter: &DapVersion, + bytes: &mut std::io::Cursor<&[u8]>, + ) -> Result { + let start = usize::try_from(bytes.position()) + .map_err(|e| prio::codec::CodecError::Other(Box::new(e)))?; + let request = AggregationJobInitReq::decode_with_param(decoding_parameter, bytes)?; + let end = usize::try_from(bytes.position()) + .map_err(|e| prio::codec::CodecError::Other(Box::new(e)))?; + + Ok(Self { + request, + hash: AggregationJobRequestHash::hash(&bytes.get_ref()[start..end]), + }) + } +} + impl_req_body! { -// body type | id type -// --------------------- | ---------------- - Report | () - AggregationJobInitReq | AggregationJobId - AggregateShareReq | () - CollectionReq | CollectionJobId - CollectionPollReq | CollectionJobId - () | () + // body type | id type + // --------------------| ---------------- + Report | () + AggregationJobInitReq | AggregationJobId + HashedAggregationJobReq | AggregationJobId + AggregateShareReq | () + CollectionReq | CollectionJobId + CollectionPollReq | CollectionJobId + () | () } /// Fields common to all DAP requests. @@ -74,6 +139,20 @@ pub struct DapRequest { pub payload: B, } +impl DapRequest { + pub fn map(self, mapper: F) -> DapRequest + where + F: FnOnce(B) -> O, + O: RequestBody, + { + DapRequest { + meta: self.meta, + resource_id: self.resource_id, + payload: mapper(self.payload), + } + } +} + impl AsRef for DapRequest { fn as_ref(&self) -> &DapRequestMeta { &self.meta diff --git a/crates/daphne/src/roles/helper/handle_agg_job.rs b/crates/daphne/src/roles/helper/handle_agg_job.rs index 58d852b25..412e01c7d 100644 --- a/crates/daphne/src/roles/helper/handle_agg_job.rs +++ b/crates/daphne/src/roles/helper/handle_agg_job.rs @@ -1,12 +1,12 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
// SPDX-License-Identifier: BSD-3-Clause use super::{check_part_batch, DapHelper}; use crate::{ error::DapAbort, messages::{ - AggregationJobInitReq, AggregationJobResp, PartialBatchSelector, ReportError, TaskId, - TransitionVar, + request::HashedAggregationJobReq, AggregationJobInitReq, AggregationJobResp, + PartialBatchSelector, ReportError, TaskId, TransitionVar, }, metrics::ReportStatus, protocol::aggregator::ReportProcessedStatus, @@ -21,7 +21,11 @@ pub struct HandleAggJob { } /// The initial state, the aggregation request has been received. -pub struct Init(DapRequest); +pub struct Init(DapRequest); + +/// The aggregation job is legal, which means that it's either new or its parameters haven't +/// changed since the last time we've received it. +pub struct LegalAggregationJobReq(DapRequest); /// The task configuration associated with the incoming request has been resolved. pub struct WithTaskConfig { @@ -61,17 +65,37 @@ macro_rules! impl_from { impl_from!(Init, WithTaskConfig, InitializedReports); -pub fn start(request: DapRequest) -> HandleAggJob { +pub fn start(request: DapRequest) -> HandleAggJob { HandleAggJob::new(request) } impl HandleAggJob { - pub fn new(request: DapRequest) -> Self { + pub fn new(request: DapRequest) -> Self { Self { state: Init(request), } } + pub async fn check_aggregation_job_legality( + self, + aggregator: &A, + ) -> Result, DapError> { + let Self { state: Init(req) } = self; + aggregator + .assert_agg_job_is_immutable( + req.resource_id, + req.version, + &req.task_id, + &req.payload.hash, + ) + .await?; + Ok(HandleAggJob { + state: LegalAggregationJobReq(req.map(|r| r.request)), + }) + } +} + +impl HandleAggJob { /// Resolve the task config in the default way.
pub async fn resolve_task_config( self, @@ -87,7 +111,7 @@ impl HandleAggJob { task_config: DapTaskConfig, ) -> Result, DapError> { let Self { - state: Init(request), + state: LegalAggregationJobReq(request), } = self; check_part_batch( diff --git a/crates/daphne/src/roles/helper/mod.rs b/crates/daphne/src/roles/helper/mod.rs index 128728905..052a326d1 100644 --- a/crates/daphne/src/roles/helper/mod.rs +++ b/crates/daphne/src/roles/helper/mod.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause pub mod handle_agg_job; @@ -11,21 +11,34 @@ use crate::{ constants::DapMediaType, error::DapAbort, messages::{ - constant_time_eq, AggregateShare, AggregateShareReq, AggregationJobInitReq, - PartialBatchSelector, TaskId, + constant_time_eq, + request::{AggregationJobRequestHash, HashedAggregationJobReq}, + AggregateShare, AggregateShareReq, AggregationJobId, PartialBatchSelector, TaskId, }, metrics::{DaphneRequestType, ReportStatus}, protocol::aggregator::ReplayProtection, - DapAggregationParam, DapError, DapRequest, DapResponse, DapTaskConfig, + DapAggregationParam, DapError, DapRequest, DapResponse, DapTaskConfig, DapVersion, }; /// DAP Helper functionality. #[async_trait] -pub trait DapHelper: DapAggregator {} +pub trait DapHelper: DapAggregator { + /// Asserts that either: + /// - this is the first time we see this aggregation job + /// - this aggregation job has been seen before and it hasn't changed since the last time we + /// saw it.
+ async fn assert_agg_job_is_immutable( + &self, + id: AggregationJobId, + version: DapVersion, + task_id: &TaskId, + req: &AggregationJobRequestHash, + ) -> Result<(), DapError>; +} pub async fn handle_agg_job_init_req( aggregator: &A, - req: DapRequest, + req: DapRequest, replay_protection: ReplayProtection, ) -> Result { let metrics = aggregator.metrics(); @@ -34,6 +47,8 @@ pub async fn handle_agg_job_init_req( let version = req.version; let agg_job_resp = handle_agg_job::start(req) + .check_aggregation_job_legality(aggregator) + .await? .resolve_task_config(aggregator) .await? .initialize_reports(aggregator, replay_protection) diff --git a/crates/daphne/src/roles/mod.rs b/crates/daphne/src/roles/mod.rs index 0754a9a61..3a9deaf7d 100644 --- a/crates/daphne/src/roles/mod.rs +++ b/crates/daphne/src/roles/mod.rs @@ -136,10 +136,11 @@ mod test { constants::DapMediaType, hpke::{HpkeKemId, HpkeProvider, HpkeReceiverConfig}, messages::{ - request::RequestBody, AggregateShareReq, AggregationJobId, AggregationJobInitReq, - AggregationJobResp, BatchId, BatchSelector, Collection, CollectionJobId, CollectionReq, - Extension, HpkeCiphertext, Interval, PartialBatchSelector, Query, Report, ReportError, - TaskId, Time, TransitionVar, + request::{HashedAggregationJobReq, RequestBody}, + AggregateShareReq, AggregationJobId, AggregationJobInitReq, AggregationJobResp, + BatchId, BatchSelector, Collection, CollectionJobId, CollectionReq, Extension, + HpkeCiphertext, Interval, PartialBatchSelector, Query, Report, ReportError, TaskId, + Time, TransitionVar, }, roles::{leader::WorkItem, DapAggregator}, testing::InMemoryAggregator, @@ -582,13 +583,16 @@ mod test { &task_config, agg_job_id, DapMediaType::AggregationJobInitReq, - AggregationJobInitReq { - agg_param: Vec::default(), - part_batch_sel: PartialBatchSelector::LeaderSelectedByBatchId { - batch_id: BatchId(rng.gen()), + HashedAggregationJobReq::from_aggregation_req( + version, + AggregationJobInitReq { + agg_param: 
Vec::default(), + part_batch_sel: PartialBatchSelector::LeaderSelectedByBatchId { + batch_id: BatchId(rng.gen()), + }, + prep_inits: Vec::default(), }, - prep_inits: Vec::default(), - }, + ), ); assert_matches!( helper::handle_agg_job_init_req(&*t.helper, req, Default::default()) @@ -748,10 +752,14 @@ mod test { // Get AggregationJobResp and then extract the transition data from inside. let agg_job_resp = AggregationJobResp::get_decoded_with_param( &version, - &helper::handle_agg_job_init_req(&*t.helper, req, Default::default()) - .await - .unwrap() - .payload, + &helper::handle_agg_job_init_req( + &*t.helper, + req.map(|req| HashedAggregationJobReq::from_aggregation_req(version, req)), + Default::default(), + ) + .await + .unwrap() + .payload, ) .unwrap(); let transition = &agg_job_resp.transitions[0]; @@ -777,10 +785,14 @@ mod test { // Get AggregationJobResp and then extract the transition data from inside. let agg_job_resp = AggregationJobResp::get_decoded_with_param( &version, - &helper::handle_agg_job_init_req(&*t.helper, req, Default::default()) - .await - .unwrap() - .payload, + &helper::handle_agg_job_init_req( + &*t.helper, + req.map(|req| HashedAggregationJobReq::from_aggregation_req(version, req)), + Default::default(), + ) + .await + .unwrap() + .payload, ) .unwrap(); let transition = &agg_job_resp.transitions[0]; diff --git a/crates/daphne/src/testing/mod.rs b/crates/daphne/src/testing/mod.rs index 23b3d2025..f324bfda8 100644 --- a/crates/daphne/src/testing/mod.rs +++ b/crates/daphne/src/testing/mod.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause //! Mock backend functionality to test DAP protocol. 
@@ -12,9 +12,9 @@ use crate::{ fatal_error, hpke::{HpkeConfig, HpkeKemId, HpkeProvider, HpkeReceiverConfig}, messages::{ - self, AggregationJobId, AggregationJobInitReq, AggregationJobResp, Base64Encode, BatchId, - BatchSelector, Collection, CollectionJobId, HpkeCiphertext, Interval, PartialBatchSelector, - Report, ReportId, TaskId, Time, + self, request::AggregationJobRequestHash, AggregationJobId, AggregationJobInitReq, + AggregationJobResp, Base64Encode, BatchId, BatchSelector, Collection, CollectionJobId, + HpkeCiphertext, Interval, PartialBatchSelector, Report, ReportId, TaskId, Time, }, metrics::{prometheus::DaphnePromMetrics, DaphneMetrics}, roles::{ @@ -34,7 +34,7 @@ use prio::codec::{ParameterizedDecode, ParameterizedEncode}; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; use std::{ - collections::{HashMap, HashSet}, + collections::{hash_map::Entry, HashMap, HashSet}, hash::Hash, num::NonZeroUsize, ops::{DerefMut, Range}, @@ -516,6 +516,9 @@ pub struct InMemoryAggregator { // Leader: Reference to peer. Used to simulate HTTP requests from Leader to Helper, i.e., // implement `DapLeader::send_http_post()` for `InMemoryAggregator`. Not set by the Helper. 
peer: Option>, + + // Helper: aggregation jobs + processed_jobs: Mutex>, } impl DeepSizeOf for InMemoryAggregator { @@ -531,6 +534,7 @@ impl DeepSizeOf for InMemoryAggregator { audit_log: _, taskprov_vdaf_verify_key_init, peer, + processed_jobs, } = self; global_config.deep_size_of_children(context) + tasks.deep_size_of_children(context) @@ -540,6 +544,7 @@ impl DeepSizeOf for InMemoryAggregator { + collector_hpke_config.deep_size_of_children(context) + taskprov_vdaf_verify_key_init.deep_size_of_children(context) + peer.deep_size_of_children(context) + + processed_jobs.deep_size_of_children(context) } } @@ -563,6 +568,7 @@ impl InMemoryAggregator { audit_log: MockAuditLog::default(), taskprov_vdaf_verify_key_init, peer: None, + processed_jobs: Default::default(), } } @@ -586,6 +592,7 @@ impl InMemoryAggregator { audit_log: MockAuditLog::default(), taskprov_vdaf_verify_key_init, peer: peer.into(), + processed_jobs: Default::default(), } } @@ -866,7 +873,32 @@ impl DapAggregator for InMemoryAggregator { } #[async_trait] -impl DapHelper for InMemoryAggregator {} +impl DapHelper for InMemoryAggregator { + async fn assert_agg_job_is_immutable( + &self, + id: AggregationJobId, + _version: DapVersion, + _task_id: &TaskId, + req: &AggregationJobRequestHash, + ) -> Result<(), DapError> { + match self.processed_jobs.lock().unwrap().entry(id) { + Entry::Occupied(occupied_entry) => { + if occupied_entry.get() == req { + Ok(()) + } else { + Err(DapAbort::BadRequest( + "chaning aggregation job parameters is illegal".to_string(), + ) + .into()) + } + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(req.clone()); + Ok(()) + } + } + } +} #[async_trait] impl DapLeader for InMemoryAggregator {