Skip to content

Move the aggregator logic back into the worker#756

Merged
mendess merged 3 commits into main from
mendess/async-helper
Jan 8, 2025
Merged

Move the aggregator logic back into the worker#756
mendess merged 3 commits into main from
mendess/async-helper

Conversation

@mendess
Copy link
Collaborator

@mendess mendess commented Jan 6, 2025

In order to implement the async change in draft 13, we want to make use of worker queues; to simplify the codebase, we elected to get rid of the storage proxy. In order to go from something like

+------------+     +---------------+     +---------------+
| nginx-like | --> | daphne-server | --> | storage-proxy |
+------------+     +---------------+     +---------------+

to something like

+---------------------+     +---------------+
| aggregator (worker) | --> | daphne-server |
+---------------------+     +---------------+
    |           ^
    |           |
    +-----------+
       queues

This PR copies code from daphne-server to daphne-worker under the directory src/aggregator in order to move the routing logic back into the worker. In a future PR, the CPU-intensive bits will be delegated to a new route in the daphne-server.

Because GitHub can't provide a diff between existing files and brand-new files, I made a small script to generate the "real diff" to ease review of the code.

Differences in crates/daphne-worker/src/aggregator/metrics.rs
--- crates/daphne-server/src/metrics.rs
+++ crates/daphne-worker/src/aggregator/metrics.rs
@@ -1,11 +1,11 @@
-// Copyright (c) 2022 Cloudflare, Inc. All rights reserved.
+// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
 //! Daphne-Worker metrics.
 
 use daphne::metrics::DaphneMetrics;
 
-pub trait DaphneServiceMetrics: DaphneMetrics {
+pub trait DaphneServiceMetrics {
     fn abort_count_inc(&self, label: &str);
     fn count_http_status_code(&self, status_code: u16);
     fn daphne(&self) -> &dyn DaphneMetrics;
@@ -19,7 +19,6 @@
     TlsClientAuth,
 }
 
-#[cfg(any(feature = "prometheus", test))]
 mod prometheus {
     use super::DaphneServiceMetrics;
     use daphne::{
@@ -137,5 +136,4 @@
     }
 }
 
-#[cfg(any(feature = "prometheus", test))]
 pub use prometheus::DaphnePromServiceMetrics;
Differences in crates/daphne-worker/src/aggregator/mod.rs
--- crates/daphne-server/src/./lib.rs
+++ crates/daphne-worker/src/aggregator/mod.rs
@@ -1,91 +1,36 @@
 // Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
-use std::sync::Arc;
+mod config;
+mod metrics;
+mod roles;
+mod router;
 
+use crate::storage::{kv, Do, Kv};
 use config::{DaphneServiceConfig, PeerBearerToken};
 use daphne::{
     audit_log::{AuditLog, NoopAuditLog},
     constants::DapRole,
     fatal_error,
-    messages::{Base64Encode, TaskId},
-    roles::{leader::in_memory_leader::InMemoryLeaderState, DapAggregator},
+    messages::TaskId,
+    roles::{leader::in_memory_leader::InMemoryLeaderState, DapAggregator as _},
     DapError,
 };
 use daphne_service_utils::bearer_token::BearerToken;
 use either::Either::{self, Left, Right};
-use futures::lock::Mutex;
 use metrics::DaphneServiceMetrics;
 use roles::BearerTokens;
-use serde::{Deserialize, Serialize};
-use storage_proxy_connection::{kv, Do, Kv};
-use url::Url;
+use router::DaphneService;
+use std::sync::{Arc, LazyLock, Mutex};
+use worker::send::SendWrapper;
 
-pub mod config;
-pub mod metrics;
-mod roles;
-pub mod router;
-mod storage_proxy_connection;
+pub use router::handle_dap_request;
 
-/// Entrypoint to the server implementation. This struct implements
-/// [`DapLeader`](daphne::roles::DapLeader) and [`DapHelper`](daphne::roles::DapHelper) and can be
-/// passed to the router.
-///
-/// It depends on a cloudflare worker to do it's storage using durable objects.
-///
-/// It can be constructed from:
-/// - a `url` that points to a cloudflare worker which serves as proxy for the storage
-///   implementation.
-/// - an implementation of [`DaphneServiceMetrics`].
-/// - a [`DaphneServiceConfig`].
-///
-/// # Examples
-/// ```
-/// use std::num::NonZeroUsize;
-/// use url::Url;
-/// use daphne::{DapGlobalConfig, constants::DapAggregatorRole, hpke::HpkeKemId, DapVersion};
-/// use daphne_server::{
-///     App,
-///     router,
-///     StorageProxyConfig,
-///     metrics::DaphnePromServiceMetrics,
-///     config::DaphneServiceConfig,
-/// };
-///
-/// let storage_proxy_settings = StorageProxyConfig {
-///     url: Url::parse("http://example.com").unwrap(),
-///     auth_token: "some-token".into(),
-/// };
-/// let registry = prometheus::Registry::new();
-/// let daphne_service_metrics = DaphnePromServiceMetrics::register(&registry).unwrap();
-/// let global = DapGlobalConfig {
-///     max_batch_duration: 360_00,
-///     min_batch_interval_start: 259_200,
-///     max_batch_interval_end: 259_200,
-///     supported_hpke_kems: vec![HpkeKemId::X25519HkdfSha256],
-///     default_num_agg_span_shards: NonZeroUsize::new(2).unwrap(),
-/// };
-/// let service_config = DaphneServiceConfig {
-///     role: DapAggregatorRole::Helper,
-///     global,
-///     base_url: None,
-///     taskprov: None,
-///     default_version: DapVersion::Draft09,
-///     report_storage_epoch_duration: 300,
-///     report_storage_max_future_time_skew: 300,
-///     signing_key: None,
-/// };
-/// let app = App::new(storage_proxy_settings, daphne_service_metrics, service_config)?;
-///
-/// let router = router::new(DapAggregatorRole::Helper, app);
-///
-/// # Ok::<(), daphne::DapError>(())
-/// ```
 pub struct App {
-    storage_proxy_config: StorageProxyConfig,
     http: reqwest::Client,
+    env: SendWrapper<worker::Env>,
     kv_state: kv::State,
-    metrics: Box<dyn DaphneServiceMetrics>,
+    metrics: Box<dyn DaphneServiceMetrics + Send + Sync>,
     service_config: DaphneServiceConfig,
     audit_log: Box<dyn AuditLog + Send + Sync>,
 
@@ -95,14 +40,10 @@
     test_leader_state: Arc<Mutex<InMemoryLeaderState>>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
-pub struct StorageProxyConfig {
-    pub url: Url,
-    pub auth_token: BearerToken,
-}
+static_assertions::assert_impl_all!(App: Send, Sync);
 
-#[axum::async_trait]
-impl router::DaphneService for App {
+#[async_trait::async_trait]
+impl DaphneService for App {
     fn server_metrics(&self) -> &dyn DaphneServiceMetrics {
         &*self.metrics
     }
@@ -158,7 +99,7 @@
         {
             Ok(())
         } else {
-            reject(format_args!("with task_id {}", task_id.to_base64url()))
+            reject(format_args!("with task_id {task_id}"))
         }
     }
 
@@ -180,45 +121,35 @@
 
 impl App {
     /// Create a new configured app. See [`App`] for details.
-    pub fn new<M>(
-        storage_proxy_config: StorageProxyConfig,
-        daphne_service_metrics: M,
-        service_config: DaphneServiceConfig,
-    ) -> Result<Self, DapError>
-    where
-        M: DaphneServiceMetrics + 'static,
-    {
+    pub fn new(
+        env: worker::Env,
+        registry: &prometheus::Registry,
+        audit_log: impl Into<Option<Box<dyn AuditLog + Send + Sync>>>,
+    ) -> Result<Self, DapError> {
+        static PERSISTENT_ENOUGH_STATE: LazyLock<Arc<Mutex<InMemoryLeaderState>>> =
+            LazyLock::new(Default::default);
+        let metrics = metrics::DaphnePromServiceMetrics::register(registry)?;
+        let service_config = config::load_config_from_env(&env)?;
         Ok(Self {
-            storage_proxy_config,
             http: reqwest::Client::new(),
+            env: SendWrapper(env),
             kv_state: Default::default(),
-            metrics: Box::new(daphne_service_metrics),
-            audit_log: Box::new(NoopAuditLog),
+            metrics: Box::new(metrics),
+            audit_log: audit_log.into().unwrap_or_else(|| Box::new(NoopAuditLog)),
             service_config,
-            test_leader_state: Default::default(),
+            test_leader_state: PERSISTENT_ENOUGH_STATE.clone(),
         })
     }
 
-    pub fn set_audit_log<A>(&mut self, audit_log: A)
-    where
-        A: AuditLog + Send + Sync + 'static,
-    {
-        self.audit_log = Box::new(audit_log);
+    fn durable(&self) -> Do<'_> {
+        Do::new(&self.env)
     }
 
-    pub(crate) fn durable(&self) -> Do<'_> {
-        Do::new(&self.storage_proxy_config, &self.http)
+    fn kv(&self) -> Kv<'_> {
+        Kv::new(&self.env, &self.kv_state)
     }
 
-    pub(crate) fn kv(&self) -> Kv<'_> {
-        Kv::new(&self.storage_proxy_config, &self.http, &self.kv_state)
-    }
-
-    pub(crate) fn bearer_tokens(&self) -> BearerTokens<'_> {
-        BearerTokens::from(Kv::new(
-            &self.storage_proxy_config,
-            &self.http,
-            &self.kv_state,
-        ))
+    fn bearer_tokens(&self) -> BearerTokens<'_> {
+        BearerTokens::from(Kv::new(&self.env, &self.kv_state))
     }
 }
Differences in crates/daphne-worker/src/aggregator/config.rs
--- crates/daphne-server/src/config.rs
+++ crates/daphne-worker/src/aggregator/config.rs
@@ -2,9 +2,8 @@
 // SPDX-License-Identifier: BSD-3-Clause
 
 use daphne::{
-    constants::DapAggregatorRole,
-    hpke::{HpkeConfig, HpkeReceiverConfig},
-    DapGlobalConfig, DapVersion,
+    constants::DapAggregatorRole, fatal_error, hpke::HpkeConfig, DapError, DapGlobalConfig,
+    DapVersion,
 };
 use daphne_service_utils::bearer_token::BearerToken;
 use p256::ecdsa::SigningKey;
@@ -15,7 +14,6 @@
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct TaskprovConfig {
     /// HPKE collector configuration for all taskprov tasks.
-    #[serde(with = "from_raw_string")]
     pub hpke_collector_config: HpkeConfig,
 
     /// VDAF verify key init secret, used to generate the VDAF verification key for a taskprov task.
@@ -26,7 +24,6 @@
     pub peer_auth: PeerBearerToken,
 
     /// Bearer token used when trying to communicate with an aggregator using taskprov.
-    #[serde(default)]
     pub self_bearer_token: Option<BearerToken>,
 }
 
@@ -40,8 +37,6 @@
     Collector { expected_token: BearerToken },
 }
 
-pub type HpkeRecieverConfigList = Vec<HpkeReceiverConfig>;
-
 /// Daphne service configuration, including long-lived parameters used across DAP tasks.
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct DaphneServiceConfig {
@@ -54,12 +49,10 @@
 
     /// draft-dcook-ppm-dap-interop-test-design: Base URL of the Aggregator (unversioned). If set,
     /// this field is used for endpoint configuration for interop testing.
-    #[serde(default)]
     pub base_url: Option<Url>,
 
     /// draft-wang-ppm-dap-taskprov: Long-lived parameters for the taskprov extension. If not set,
     /// then taskprov will be disabled.
-    #[serde(default)]
     pub taskprov: Option<TaskprovConfig>,
 
     /// Default DAP version to use if not specified by the API URL
@@ -211,33 +204,101 @@
     Dev,
 }
 
-mod from_raw_string {
-    //! This is used to deserialize secrets, which are stored in as raw strings. As such they need
-    //! a custom deserializer.
-
-    use serde::{
-        de::{self, DeserializeOwned},
-        ser, Deserialize, Deserializer, Serialize, Serializer,
-    };
-
-    pub fn serialize<T, S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-        S::Error: ser::Error,
-        T: Serialize,
-    {
-        serde_json::to_string(value)
-            .map_err(<S::Error as ser::Error>::custom)
-            .and_then(|s| serializer.serialize_str(&s))
-    }
-
-    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
-    where
-        D: Deserializer<'de>,
-        D::Error: de::Error,
-        T: DeserializeOwned,
-    {
-        let s = String::deserialize(deserializer)?;
-        serde_json::from_str(&s).map_err(<D::Error as de::Error>::custom)
-    }
-}
+pub fn load_config_from_env(env: &worker::Env) -> Result<DaphneServiceConfig, DapError> {
+    const SERVICE_CONFIG: &str = "SERVICE_CONFIG";
+    const SIGNING_KEY: &str = "SIGNING_KEY";
+
+    let mut config = env
+        .object_var::<DaphneServiceConfig>(SERVICE_CONFIG)
+        .map_err(|e| fatal_error!(err = ?e, "failed to load SERVICE_CONFIG variable"))?;
+
+    if config.taskprov.is_some() {
+        tracing::warn!("taskprov secrets are defined in plain text. Prefer using wrangler secrets");
+    } else if matches!(env.var(taskprov_secrets::ENABLED), Ok(s) if s.to_string() == "true") {
+        config.taskprov = Some(taskprov_secrets::load(env)?);
+    }
+
+    if config.signing_key.is_some() {
+        tracing::warn!("signing key is defined in plain text. Prefer using wrangler secrets");
+    } else {
+        config.signing_key = env
+            .var(SIGNING_KEY)
+            .ok()
+            .map(|s| p256::SecretKey::from_sec1_pem(&s.to_string()).map(SigningKey::from))
+            .transpose()
+            .map_err(|e| fatal_error!(err = ?e, "failed to deserialize SIGNING_KEY"))?
+    }
+    Ok(config)
+}
+
+mod taskprov_secrets {
+    use super::{PeerBearerToken, TaskprovConfig};
+    use daphne::{fatal_error, DapError};
+    use daphne_service_utils::bearer_token::BearerToken;
+
+    pub const ENABLED: &str = constcat::concat!(TASKPROV_SECRETS, "_", "ENABLED");
+
+    const TASKPROV_SECRETS: &str = "TASKPROV_SECRETS";
+    const VDAF_VERIFY_KEY_INIT: &str =
+        constcat::concat!(TASKPROV_SECRETS, "_", "VDAF_VERIFY_KEY_INIT");
+    const PEER_AUTH_LEADER_EXPECTED_TOKEN: &str =
+        constcat::concat!(TASKPROV_SECRETS, "_", "PEER_AUTH_EXPECT_LEADER_TOKEN");
+    const PEER_AUTH_COLLECTOR_EXPECTED_TOKEN: &str =
+        constcat::concat!(TASKPROV_SECRETS, "_", "PEER_AUTH_EXPECT_COLLECTOR_TOKEN");
+    const SELF_BEARER_TOKEN: &str = constcat::concat!(TASKPROV_SECRETS, "_", "SELF_BEARER_TOKEN");
+    const TASKPROV_HPKE_COLLECTOR_CONFIG: &str = "TASKPROV_HPKE_COLLECTOR_CONFIG";
+
+    pub fn load(env: &worker::Env) -> Result<TaskprovConfig, DapError> {
+        Ok(super::TaskprovConfig {
+            hpke_collector_config: env.object_var(TASKPROV_HPKE_COLLECTOR_CONFIG).map_err(
+                |e| fatal_error!(err = ?e, "failed to load TASKPROV_HPKE_COLLECTOR_CONFIG"),
+            )?,
+            vdaf_verify_key_init: {
+                let key = VDAF_VERIFY_KEY_INIT;
+                hex::decode(
+                    env.var(key)
+                        .map(|t| t.to_string())
+                        .map_err(|e| fatal_error!(err = ?e, "failed to load {key}"))?,
+                )
+                .map_err(|e| fatal_error!(err = ?e, "invalid {key}"))?
+                .try_into()
+                .map_err(|e: Vec<_>| {
+                    fatal_error!(
+                        err = format!("{key} of invalid length. Got {} expected 32", e.len())
+                    )
+                })?
+            },
+            peer_auth: match (
+                env.var(PEER_AUTH_LEADER_EXPECTED_TOKEN),
+                env.var(PEER_AUTH_COLLECTOR_EXPECTED_TOKEN),
+            ) {
+                (Ok(_), Ok(_)) => {
+                    return Err(fatal_error!(
+                        err = format!(
+                            "{} and {} were defined simultaneously, this is not allowed",
+                            PEER_AUTH_LEADER_EXPECTED_TOKEN, PEER_AUTH_COLLECTOR_EXPECTED_TOKEN
+                        )
+                    ))
+                }
+                (Ok(leader), _) => PeerBearerToken::Leader {
+                    expected_token: leader.to_string().into(),
+                },
+                (_, Ok(collector)) => PeerBearerToken::Collector {
+                    expected_token: collector.to_string().into(),
+                },
+                (Err(e), _) => {
+                    return Err(fatal_error!(
+                        err = ?e,
+                        "failed to load {} or {}",
+                        PEER_AUTH_LEADER_EXPECTED_TOKEN,
+                        PEER_AUTH_COLLECTOR_EXPECTED_TOKEN
+                    ))
+                }
+            },
+            self_bearer_token: env
+                .var(SELF_BEARER_TOKEN)
+                .ok()
+                .map(|t| BearerToken::from(t.to_string())),
+        })
+    }
+}
Differences in crates/daphne-worker/src/aggregator/router/test_routes.rs
--- crates/daphne-server/src/router/test_routes.rs
+++ crates/daphne-worker/src/aggregator/router/test_routes.rs
@@ -20,7 +20,7 @@
 use daphne_service_utils::test_route_types::{InternalTestAddTask, InternalTestEndpointForTask};
 use serde::Deserialize;
 
-use crate::App;
+use super::App;
 
 use super::{AxumDapResponse, DaphneService};
 
@@ -38,7 +38,7 @@
 
     router
         .route("/internal/delete_all", post(delete_all))
-        .route("/internal/test/ready", post(check_storage_readyness))
+        .route("/internal/test/ready", post(StatusCode::OK))
         .route(
             "/internal/test/endpoint_for_task",
             post(endpoint_for_task_default),
@@ -57,14 +57,6 @@
             "/:version/internal/test/add_hpke_config",
             post(add_hpke_config),
         )
-}
-
-#[tracing::instrument(skip(app))]
-async fn check_storage_readyness(State(app): State<Arc<App>>) -> Response {
-    match app.storage_ready_check().await {
-        Ok(()) => StatusCode::OK.into_response(),
-        Err(e) => AxumDapResponse::new_error(e, &*app.metrics).into_response(),
-    }
 }
 
 #[tracing::instrument(skip(app))]
@@ -92,6 +84,7 @@
 }
 
 #[tracing::instrument(skip(app))]
+#[worker::send]
 async fn delete_all(State(app): State<Arc<App>>) -> impl IntoResponse {
     match app.internal_delete_all().await {
         Ok(()) => StatusCode::OK.into_response(),
@@ -149,6 +142,7 @@
 }
 
 #[tracing::instrument(skip(app, hpke))]
+#[worker::send]
 async fn add_hpke_config(
     State(app): State<Arc<App>>,
     Path(version): Path<DapVersion>,
@@ -165,6 +159,7 @@
 }
 
 #[tracing::instrument(skip(app, json))]
+#[worker::send]
 async fn add_hpke_config_default(
     State(app): State<Arc<App>>,
     json: Json<HpkeReceiverConfig>,
Differences in crates/daphne-worker/src/aggregator/router/extractor.rs
--- crates/daphne-server/src/router/extractor.rs
+++ crates/daphne-worker/src/aggregator/router/extractor.rs
@@ -24,7 +24,7 @@
 use prio::codec::ParameterizedDecode;
 use serde::Deserialize;
 
-use crate::metrics;
+use super::super::metrics;
 
 use super::{AxumDapResponse, DaphneService};
 
@@ -258,8 +258,8 @@
     async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
         let bearer_token = extract_header_as_str(req.headers(), http_headers::DAP_AUTH_TOKEN)
             .map(BearerToken::from);
-        let cf_tls_client_auth =
-            extract_header_as_str(req.headers(), "X-Client-Cert-Verified").map(ToString::to_string);
+
+        let cf_tls_client_auth = mtls_auth_from_request(&req);
 
         let request = UnauthenticatedDapRequestExtractor::from_request(req, state)
             .await?
@@ -298,21 +298,19 @@
         } else {
             false
         };
-        let mtls_authed = if let Some(verification_result) = cf_tls_client_auth {
+
+        // attempt to auth with mtls
+        if matches!(cf_tls_client_auth, Err(_) | Ok(true)) {
             state
                 .server_metrics()
                 .auth_method_inc(metrics::AuthMethod::TlsClientAuth);
-            // we always check if mtls succedded even if ...
-            if verification_result != "SUCCESS" {
-                return Err(auth_error(format!(
-                    "Invalid TLS certificate ({verification_result})"
-                )));
-            }
+        }
+
+        let mtls_authed = cf_tls_client_auth
+            // we always check if mtls succedded, even if ...
+            .map_err(auth_error)?
             // ... we only allow mtls auth for taskprov tasks
-            is_taskprov
-        } else {
-            false
-        };
+            && is_taskprov;
 
         if bearer_authed || mtls_authed {
             Ok(Self(request))
@@ -322,6 +320,44 @@
             ))
         }
     }
+}
+
+/// Check if there was an mtls auth attempt.
+///
+/// - Ok(false) means no certificate was presented
+/// - Ok(true) means a valid certificate was presented
+/// - Err(e) means an invalid certificate was presented
+pub(crate) fn mtls_auth_from_request<B>(req: &http::Request<B>) -> Result<bool, String> {
+    // The runtime gives us a cf_tls_client_auth whether the communication was secured by
+    // it or not, so if a certificate wasn't presented, treat it as if it weren't there.
+    // We only check for the validity of the certificate if it was present in the daphne server
+    // Literal "1" indicates that a certificate was presented.
+    #[cfg(test)]
+    let not_present = &test::ClientTlsAuthMock::do_not_present();
+    #[cfg(test)]
+    let cf = Some(
+        req.extensions()
+            .get::<test::ClientTlsAuthMock>()
+            .unwrap_or(not_present),
+    );
+
+    #[cfg(not(test))]
+    let cf = req
+        .extensions()
+        .get::<worker::Cf>()
+        .expect("cf object should always be present")
+        .tls_client_auth();
+
+    cf.filter(|auth| auth.cert_presented() == "1")
+        .map(|cert| {
+            let verified = cert.cert_verified();
+            match verified.as_str() {
+                "SUCCESS" => Ok(()),
+                _ => Err(format!("Invalid TLS certificate ({verified})")),
+            }
+        })
+        .transpose()
+        .map(|present| present.is_some())
 }
 
 fn extract_header_as_str<'s>(headers: &'s HeaderMap, header: &'static str) -> Option<&'s str> {
@@ -363,14 +399,39 @@
     };
     use tower::ServiceExt;
 
-    use crate::{
-        metrics::{DaphnePromServiceMetrics, DaphneServiceMetrics},
-        router::extractor::UnauthenticatedDapRequestExtractor,
+    use super::{
+        dap_sender::FROM_LEADER, metrics::DaphneServiceMetrics, resource_parsers,
+        DecodeFromDapHttpBody, UnauthenticatedDapRequestExtractor,
     };
-
-    use super::{dap_sender::FROM_LEADER, resource_parsers, DecodeFromDapHttpBody};
+    use crate::aggregator::metrics::DaphnePromServiceMetrics;
     use http::{header, StatusCode};
     use prio::codec::ParameterizedEncode;
+
+    /// We can't mock [`worker::Cf`] but we can mock a type that behaves similarly enough for us to
+    /// test the mtls logic.
+    #[derive(Debug, Clone)]
+    pub struct ClientTlsAuthMock(Option<&'static str>);
+
+    impl ClientTlsAuthMock {
+        pub fn present_success() -> Self {
+            Self(Some("SUCCESS"))
+        }
+        pub fn present_failure() -> Self {
+            Self(Some("FAILURE"))
+        }
+        pub fn do_not_present() -> Self {
+            Self(None)
+        }
+
+        #[allow(clippy::unused_self)]
+        pub fn cert_presented(&self) -> &'static str {
+            self.0.map_or("0", |_| "1")
+        }
+
+        pub fn cert_verified(&self) -> String {
+            self.0.map_or_else(|| "missing".into(), |v| v.to_string())
+        }
+    }
 
     const BEARER_TOKEN: &str = "test-token";
 
@@ -680,7 +741,7 @@
                         .compute_task_id(version)
                         .to_base64url()
                 ))
-                .header("X-Client-Cert-Verified", "SUCCESS")
+                .extension(ClientTlsAuthMock::present_success())
                 .header(
                     http_headers::DAP_TASKPROV,
                     taskprov_advertisement
@@ -702,7 +763,7 @@
             Request::builder()
                 .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
                 .header(http_headers::DAP_AUTH_TOKEN, "something incorrect")
-                .header("X-Client-Cert-Verified", "SUCCESS")
+                .extension(ClientTlsAuthMock::present_success())
                 .body(Body::empty())
                 .unwrap(),
         )
@@ -719,7 +780,7 @@
             Request::builder()
                 .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
                 .header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN)
-                .header("X-Client-Cert-Verified", "FAILED")
+                .extension(ClientTlsAuthMock::present_failure())
                 .body(Body::empty())
                 .unwrap(),
         )
Differences in crates/daphne-worker/src/aggregator/router/mod.rs
--- crates/daphne-server/src/router/mod.rs
+++ crates/daphne-worker/src/aggregator/router/mod.rs
@@ -14,7 +14,7 @@
     extract::{Request, State},
     http::{header::CONTENT_TYPE, HeaderValue, StatusCode},
     middleware::Next,
-    response::IntoResponse,
+    response::{IntoResponse, Response},
     Json,
 };
 use daphne::{
@@ -27,8 +27,10 @@
 use daphne_service_utils::bearer_token::BearerToken;
 use either::Either;
 
-use crate::{metrics::DaphneServiceMetrics, App};
+use super::{metrics::DaphneServiceMetrics, App};
 use extractor::{DapRequestExtractor, UnauthenticatedDapRequestExtractor};
+use tower::ServiceExt as _;
+use worker::HttpRequest;
 
 type Router<A> = axum::Router<Arc<A>>;
 
@@ -89,18 +91,18 @@
     }
 }
 
-pub fn new(role: DapAggregatorRole, aggregator: App) -> axum::Router<()> {
+pub async fn handle_dap_request(app: App, req: HttpRequest) -> Response {
     let router = axum::Router::new();
 
     let router = aggregator::add_aggregator_routes(router);
 
-    let router = match role {
+    let router = match app.service_config.role {
         DapAggregatorRole::Leader => leader::add_leader_routes(router),
         DapAggregatorRole::Helper => helper::add_helper_routes(router),
     };
 
     #[cfg(feature = "test-utils")]
-    let router = test_routes::add_test_routes(router, role);
+    let router = test_routes::add_test_routes(router, app.service_config.role);
 
     async fn request_metrics(
         State(app): State<Arc<App>>,
@@ -115,15 +117,19 @@
         resp
     }
 
-    let app = Arc::new(aggregator);
-    router
-        .with_state(app.clone())
+    let aggregator = Arc::new(app);
+    let Ok(response) = router
+        .with_state(aggregator.clone())
         .layer(
             tower::ServiceBuilder::new().layer(axum::middleware::from_fn_with_state(
-                app.clone(),
+                aggregator,
                 request_metrics,
             )),
         )
+        .oneshot(req)
+        .await;
+
+    response
 }
 
 struct AxumDapResponse(axum::response::Response);
@@ -160,7 +166,7 @@
     }
 
     pub fn new_error<E: Into<DapError>>(error: E, metrics: &dyn DaphneServiceMetrics) -> Self {
-        // Trigger abort if report errors reach this point.
+        // Trigger abort if reports errors reach this point.
         let error = match error.into() {
             DapError::ReportError(err) => DapAbort::report_rejected(err),
             DapError::Fatal(e) => Err(e),
Differences in crates/daphne-worker/src/aggregator/router/helper.rs
--- crates/daphne-server/src/router/helper.rs
+++ crates/daphne-worker/src/aggregator/router/helper.rs
@@ -13,11 +13,11 @@
 };
 use http::StatusCode;
 
-use crate::{roles::fetch_replay_protection_override, App};
-
 use super::{
-    extractor::dap_sender::FROM_LEADER, AxumDapResponse, DapRequestExtractor, DaphneService,
+    super::roles::fetch_replay_protection_override, extractor::dap_sender::FROM_LEADER, App,
+    AxumDapResponse, DapRequestExtractor, DaphneService,
 };
+use crate::elapsed;
 
 pub(super) fn add_helper_routes(router: super::Router<App>) -> super::Router<App> {
     router
@@ -36,11 +36,12 @@
         version = ?req.version,
     )
 )]
+#[worker::send]
 async fn agg_job(
     State(app): State<Arc<App>>,
     DapRequestExtractor(req): DapRequestExtractor<FROM_LEADER, AggregationJobInitReq>,
 ) -> AxumDapResponse {
-    let timer = std::time::Instant::now();
+    let now = worker::Date::now();
 
     let resp = helper::handle_agg_job_init_req(
         &*app,
@@ -49,7 +50,7 @@
     )
     .await;
 
-    let elapsed = timer.elapsed();
+    let elapsed = elapsed(&now);
 
     app.server_metrics().aggregate_job_latency(elapsed);
 
Differences in crates/daphne-worker/src/aggregator/roles/leader.rs
--- crates/daphne-server/src/roles/leader.rs
+++ crates/daphne-worker/src/aggregator/roles/leader.rs
@@ -1,34 +1,36 @@
 // Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
-use std::{borrow::Cow, time::Instant};
-
+use std::borrow::Cow;
+
+use crate::{aggregator::App, elapsed};
 use axum::{async_trait, http::Method};
 use daphne::{
     constants::{DapMediaType, DapRole},
     error::DapAbort,
     fatal_error,
     messages::{BatchId, BatchSelector, Collection, CollectionJobId, Report, TaskId},
-    roles::{leader::WorkItem, DapAggregator, DapLeader},
+    roles::{leader::WorkItem, DapAggregator as _, DapLeader},
     DapAggregationParam, DapCollectionJob, DapError, DapRequestMeta, DapResponse, DapVersion,
 };
 use daphne_service_utils::http_headers;
-use http::StatusCode;
+use http::{header, HeaderMap, HeaderName, HeaderValue, StatusCode};
 use prio::codec::ParameterizedEncode;
 use tracing::{error, info};
 use url::Url;
 
 #[async_trait]
-impl DapLeader for crate::App {
+impl DapLeader for App {
     async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError> {
         let task_config = self
             .get_task_config_for(task_id)
             .await?
             .ok_or(DapAbort::UnrecognizedTask { task_id: *task_id })?;
 
-        self.test_leader_state
-            .lock()
-            .await
+        worker::console_log!("{:?}", &*self.test_leader_state.lock().unwrap());
+        self.test_leader_state
+            .lock()
+            .unwrap()
             .put_report(task_id, &task_config, report.clone())
     }
 
@@ -40,9 +42,10 @@
                 task_id: *task_id,
             }))?;
 
-        self.test_leader_state
-            .lock()
-            .await
+        worker::console_log!("{:?}", &*self.test_leader_state.lock().unwrap());
+        self.test_leader_state
+            .lock()
+            .unwrap()
             .current_batch(task_id, &task_config)
     }
 
@@ -58,7 +61,8 @@
             .await?
             .ok_or(DapAbort::UnrecognizedTask { task_id: *task_id })?;
 
-        self.test_leader_state.lock().await.init_collect_job(
+        worker::console_log!("{:?}", &*self.test_leader_state.lock().unwrap());
+        self.test_leader_state.lock().unwrap().init_collect_job(
             task_id,
             &task_config,
             coll_job_id,
@@ -72,9 +76,10 @@
         task_id: &TaskId,
         coll_job_id: &CollectionJobId,
     ) -> Result<DapCollectionJob, DapError> {
-        self.test_leader_state
-            .lock()
-            .await
+        worker::console_log!("{:?}", &*self.test_leader_state.lock().unwrap());
+        self.test_leader_state
+            .lock()
+            .unwrap()
             .poll_collect_job(task_id, coll_job_id)
     }
 
@@ -84,18 +89,24 @@
         coll_job_id: &CollectionJobId,
         collection: &Collection,
     ) -> Result<(), DapError> {
-        self.test_leader_state
-            .lock()
-            .await
+        worker::console_log!("{:?}", &*self.test_leader_state.lock().unwrap());
+        self.test_leader_state
+            .lock()
+            .unwrap()
             .finish_collect_job(task_id, coll_job_id, collection)
     }
 
     async fn dequeue_work(&self, num_items: usize) -> Result<Vec<WorkItem>, DapError> {
-        self.test_leader_state.lock().await.dequeue_work(num_items)
+        worker::console_log!("{:?}", &*self.test_leader_state.lock().unwrap());
+        self.test_leader_state
+            .lock()
+            .unwrap()
+            .dequeue_work(num_items)
     }
 
     async fn enqueue_work(&self, items: Vec<WorkItem>) -> Result<(), DapError> {
-        self.test_leader_state.lock().await.enqueue_work(items)
+        worker::console_log!("{:?}", &*self.test_leader_state.lock().unwrap());
+        self.test_leader_state.lock().unwrap().enqueue_work(items)
     }
 
     async fn send_http_post<P>(
@@ -123,7 +134,8 @@
     }
 }
 
-impl crate::App {
+impl App {
+    #[worker::send]
     async fn send_http<P>(
         &self,
         meta: DapRequestMeta,
@@ -134,7 +146,6 @@
     where
         P: Send + ParameterizedEncode<DapVersion>,
     {
-        use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
         let content_type = meta
             .media_type
             .and_then(|mt| mt.as_str_for_version(meta.version))
@@ -211,19 +222,19 @@
             )
             .headers(headers);
 
-        let start = Instant::now();
+        let start = worker::Date::now();
         let reqwest_resp = req_builder
             .send()
             .await
             .map_err(|e| fatal_error!(err = ?e, "failed to send request to the helper"))?;
-        info!("request to {} completed in {:?}", url, start.elapsed());
+        info!("request to {} completed in {:?}", url, elapsed(&start));
         let status = reqwest_resp.status();
 
         if status.is_success() {
             // Translate the reqwest response into a Worker response.
             let media_type = reqwest_resp
                 .headers()
-                .get_all(reqwest::header::CONTENT_TYPE)
+                .get_all(header::CONTENT_TYPE)
                 .into_iter()
                 .filter_map(|h| h.to_str().ok())
                 .find_map(|h| DapMediaType::from_str_for_version(meta.version, h))
@@ -244,9 +255,7 @@
             error!("{}: request failed: {:?}", url, reqwest_resp);
             match status {
                 StatusCode::BAD_REQUEST => {
-                    if let Some(content_type) =
-                        reqwest_resp.headers().get(reqwest::header::CONTENT_TYPE)
-                    {
+                    if let Some(content_type) = reqwest_resp.headers().get(header::CONTENT_TYPE) {
                         if content_type == "application/problem+json" {
                             error!(
                                 "Problem details: {}",
Differences in crates/daphne-worker/src/aggregator/roles/aggregator.rs
--- crates/daphne-server/src/roles/aggregator.rs
+++ crates/daphne-worker/src/aggregator/roles/aggregator.rs
@@ -1,15 +1,17 @@
 // Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
-use std::{future::ready, num::NonZeroUsize, ops::Range, time::SystemTime};
-
-use axum::async_trait;
+use crate::{
+    aggregator::App,
+    storage::kv::{self, KvGetOptions},
+};
 use daphne::{
     audit_log::AuditLog,
+    constants::DapAggregatorRole,
     error::DapAbort,
     fatal_error,
-    hpke::{self, info_and_aad, HpkeConfig, HpkeProvider, HpkeReceiverConfig},
-    messages::{self, BatchId, BatchSelector, HpkeCiphertext, TaskId, Time},
+    hpke::{self, info_and_aad, HpkeConfig, HpkeReceiverConfig},
+    messages::{self, BatchId, BatchSelector, HpkeCiphertext, TaskId},
     metrics::DaphneMetrics,
     roles::{
         aggregator::{MergeAggShareError, TaskprovConfig},
@@ -21,16 +23,13 @@
 use daphne_service_utils::durable_requests::bindings::{
     self, AggregateStoreMergeOptions, AggregateStoreMergeReq, AggregateStoreMergeResp,
 };
-use futures::{future::try_join_all, StreamExt, TryFutureExt, TryStreamExt};
+use futures::{future::try_join_all, StreamExt as _, TryFutureExt as _, TryStreamExt as _};
 use mappable_rc::Marc;
-
-use crate::{
-    roles::fetch_replay_protection_override,
-    storage_proxy_connection::kv::{self, KvGetOptions},
-};
-
-#[async_trait]
-impl DapAggregator for crate::App {
+use std::{num::NonZeroUsize, ops::Range};
+use worker::send::SendFuture;
+
+#[async_trait::async_trait]
+impl DapAggregator for App {
     #[tracing::instrument(skip(self, task_config, agg_share_span))]
     async fn try_put_agg_share_span(
         &self,
@@ -40,7 +39,7 @@
     ) -> DapAggregateSpan<Result<(), MergeAggShareError>> {
         let durable = self.durable();
 
-        let replay_protection = fetch_replay_protection_override(self.kv()).await;
+        let replay_protection = super::fetch_replay_protection_override(self.kv()).await;
 
         futures::stream::iter(agg_share_span)
             .map(|(bucket, (agg_share, report_metadatas))| async {
@@ -184,27 +183,28 @@
         task_id: &TaskId,
         task_config: taskprov::DapTaskConfigNeedsOptIn,
     ) -> Result<DapTaskConfig, DapError> {
-        let param = self
-            .kv()
-            .get_or_insert_with::<kv::prefix::TaskprovOptInParam, _, _>(
-                task_id,
-                &KvGetOptions::default(),
-                || async {
-                    let global_config = self.get_global_config().await?;
-                    Ok::<_, DapError>(taskprov::OptInParam {
-                        not_before: self.get_current_time(),
-                        num_agg_span_shards: global_config.default_num_agg_span_shards,
-                    })
-                },
-                Some(task_config.task_expiration),
-            )
-            .await
-            .map_err(|e| match &*e {
-                kv::GetOrInsertError::Other(e) => e.clone(),
-                kv::GetOrInsertError::StorageProxy(e) => {
-                    fatal_error!(err = ?e, "failed to get TaskprovOptInParam from kv")
-                }
-            })?;
+        let param = SendFuture::new(
+            self.kv()
+                .get_or_insert_with::<kv::prefix::TaskprovOptInParam, _, _>(
+                    task_id,
+                    &KvGetOptions::default(),
+                    || async {
+                        let global_config = self.get_global_config().await?;
+                        Ok::<_, DapError>(taskprov::OptInParam {
+                            not_before: self.get_current_time(),
+                            num_agg_span_shards: global_config.default_num_agg_span_shards,
+                        })
+                    },
+                    Some(task_config.task_expiration),
+                ),
+        )
+        .await
+        .map_err(|e| match e {
+            kv::GetOrInsertError::Other(e) => e.clone(),
+            kv::GetOrInsertError::StorageProxy(e) => {
+                fatal_error!(err = ?e, "failed to get TaskprovOptInParam from kv")
+            }
+        })?;
 
         Ok(task_config.into_opted_in(&param))
     }
@@ -217,17 +217,16 @@
         let expiration_time = task_config.not_after;
 
         match self.service_config.role {
-            daphne::constants::DapAggregatorRole::Leader => {
-                self.kv()
-                    .put_with_expiration::<kv::prefix::TaskConfig>(
-                        task_id,
-                        task_config,
-                        expiration_time,
-                    )
-                    .await
-                    .map_err(|e| fatal_error!(err = ?e, "failed to put the a task config in kv"))?;
+            DapAggregatorRole::Leader => {
+                SendFuture::new(self.kv().put_with_expiration::<kv::prefix::TaskConfig>(
+                    task_id,
+                    task_config,
+                    expiration_time,
+                ))
+                .await
+                .map_err(|e| fatal_error!(err = ?e, "failed to put the a task config in kv"))?;
             }
-            daphne::constants::DapAggregatorRole::Helper => {
+            DapAggregatorRole::Helper => {
                 self.kv()
                     .only_cache_put::<kv::prefix::TaskConfig>(task_id, task_config)
                     .await;
@@ -246,11 +245,8 @@
             .map_err(|e| fatal_error!(err = ?e, "failed to get a task config from kv: {task_id}"))
     }
 
-    fn get_current_time(&self) -> Time {
-        SystemTime::now()
-            .duration_since(SystemTime::UNIX_EPOCH)
-            .unwrap()
-            .as_secs()
+    fn get_current_time(&self) -> messages::Time {
+        worker::Date::now().as_millis() / 1000
     }
 
     async fn is_batch_overlapping(
@@ -280,7 +276,7 @@
                         .send()
                 })
                 .buffer_unordered(usize::MAX)
-                .try_any(ready)
+                .try_any(std::future::ready)
                 .await
                 .map_err(
                     |e| fatal_error!(err = ?e, "failed to check if agg shares are collected"),
@@ -349,10 +345,7 @@
     }
 
     fn valid_report_time_range(&self) -> Range<messages::Time> {
-        let now = SystemTime::now()
-            .duration_since(SystemTime::UNIX_EPOCH)
-            .expect("now should always be after unix epoch")
-            .as_secs();
+        let now = self.get_current_time();
 
         let start = now.saturating_sub(self.service_config.report_storage_epoch_duration);
         let end = now.saturating_add(self.service_config.report_storage_max_future_time_skew);
@@ -373,8 +366,8 @@
     }
 }
 
-#[async_trait]
-impl HpkeProvider for crate::App {
+#[async_trait::async_trait]
+impl hpke::HpkeProvider for App {
     type WrappedHpkeConfig<'s> = Marc<HpkeConfig>;
     type ReceiverConfigs<'s> = HpkeDecrypter;
 
Differences in crates/daphne-worker/src/aggregator/roles/mod.rs
--- crates/daphne-server/src/roles/mod.rs
+++ crates/daphne-worker/src/aggregator/roles/mod.rs
@@ -1,18 +1,17 @@
 // Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
-
-use daphne::{constants::DapRole, messages::TaskId, ReplayProtection};
-use daphne_service_utils::bearer_token::BearerToken;
-
-use crate::storage_proxy_connection::{
-    self,
-    kv::{self, Kv, KvGetOptions},
-};
-use mappable_rc::Marc;
 
 mod aggregator;
 mod helper;
 mod leader;
+
+use crate::storage::{
+    self,
+    kv::{self, KvGetOptions},
+    Kv,
+};
+use daphne::{constants::DapRole, messages::TaskId, ReplayProtection};
+use daphne_service_utils::bearer_token::BearerToken;
 
 pub async fn fetch_replay_protection_override(kv: Kv<'_>) -> ReplayProtection {
     let skip_replay_protection = kv
@@ -57,7 +56,7 @@
         sender: DapRole,
         task_id: TaskId,
         token: BearerToken,
-    ) -> Result<Option<BearerToken>, storage_proxy_connection::Error> {
+    ) -> Result<Option<BearerToken>, storage::Error> {
         self.kv
             .put_if_not_exists::<kv::prefix::KvBearerToken>(&(sender, task_id).into(), token)
             .await
@@ -75,7 +74,7 @@
         sender: DapRole,
         task_id: TaskId,
         token: &BearerToken,
-    ) -> Result<bool, Marc<storage_proxy_connection::Error>> {
+    ) -> Result<bool, storage::Error> {
         self.kv
             .peek::<kv::prefix::KvBearerToken, _, _>(
                 &(sender, task_id).into(),
@@ -92,7 +91,7 @@
         &self,
         sender: DapRole,
         task_id: TaskId,
-    ) -> Result<Option<BearerToken>, Marc<storage_proxy_connection::Error>> {
+    ) -> Result<Option<BearerToken>, storage::Error> {
         self.kv
             .get_cloned::<kv::prefix::KvBearerToken>(
                 &(sender, task_id).into(),
@@ -106,12 +105,14 @@
 
 #[cfg(feature = "test-utils")]
 mod test_utils {
+    use super::super::App;
+    use crate::{storage::kv, storage_proxy};
     use daphne::{
         constants::{DapAggregatorRole, DapRole},
         fatal_error,
         hpke::{HpkeConfig, HpkeReceiverConfig},
         messages::decode_base64url_vec,
-        roles::DapAggregator,
+        roles::DapAggregator as _,
         vdaf::{Prio3Config, VdafConfig},
         DapBatchMode, DapError, DapTaskConfig, DapVersion,
     };
@@ -122,39 +123,20 @@
     use prio::codec::Decode;
     use std::num::NonZeroUsize;
 
-    use crate::storage_proxy_connection::kv;
-
-    impl crate::App {
+    impl App {
         pub(crate) async fn internal_delete_all(&self) -> Result<(), DapError> {
-            self.test_leader_state.lock().await.delete_all();
-
-            use daphne_service_utils::durable_requests::PURGE_STORAGE;
-            self.kv_state.reset().await;
-
-            self.http
-                .delete(self.storage_proxy_config.url.join(PURGE_STORAGE).unwrap())
-                .bearer_auth(self.storage_proxy_config.auth_token.as_str())
-                .send()
-                .await
-                .map_err(
-                    |e| fatal_error!(err = ?e, "failed to send delete request to storage proxy"),
-                )?
-                .error_for_status()
-                .map_err(|e| fatal_error!(err = ?e, "failed to clear storage proxy"))?;
-
-            Ok(())
-        }
-
-        pub(crate) async fn storage_ready_check(&self) -> Result<(), DapError> {
-            use daphne_service_utils::durable_requests::STORAGE_READY;
-            self.http
-                .get(self.storage_proxy_config.url.join(STORAGE_READY).unwrap())
-                .bearer_auth(self.storage_proxy_config.auth_token.as_str())
-                .send()
-                .await
-                .map_err(|e| fatal_error!(err = ?e, "failed to send ready check request to storage proxy"))?
-                .error_for_status()
-                .map_err(|e| fatal_error!(err = ?e, "storage proxy is not ready"))?;
+            tracing::info!("deleting leader state");
+            self.test_leader_state.lock().unwrap().delete_all();
+
+            tracing::info!("deleting kv state");
+            // use daphne_service_utils::durable_requests::PURGE_STORAGE;
+            self.kv_state.reset();
+
+            tracing::info!("purging storage");
+            storage_proxy::storage_purge(&self.env.0)
+                .await
+                .map_err(|e| fatal_error!(err = ?e, "failed to purge storage"))?;
+
             Ok(())
         }
 
Differences in crates/daphne-worker/src/aggregator/roles/helper.rs
--- crates/daphne-server/src/roles/helper.rs
+++ crates/daphne-worker/src/aggregator/roles/helper.rs
@@ -1,8 +1,7 @@
 // Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
-use axum::async_trait;
+use crate::aggregator::App;
 use daphne::roles::DapHelper;
 
-#[async_trait]
-impl DapHelper for crate::App {}
+impl DapHelper for App {}

Everything inside `daphne-worker/src/storage` is new; it serves as a substitute for the `daphne-server/src/storage-proxy-connection` module.

Differences in crates/daphne-worker/src/storage/mod.rs
--- crates/daphne-server/src/storage_proxy_connection/mod.rs
+++ crates/daphne-worker/src/storage/mod.rs
@@ -3,43 +3,36 @@
 
 pub(crate) mod kv;
 
-use std::fmt::Debug;
-
+use crate::storage_proxy;
 use axum::http::StatusCode;
 use daphne_service_utils::{
-    capnproto_payload::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt as _},
-    durable_requests::{bindings::DurableMethod, DurableRequest, ObjectIdFrom, DO_PATH_PREFIX},
+    capnproto_payload::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt},
+    durable_requests::{bindings::DurableMethod, DurableRequest, ObjectIdFrom},
 };
+pub(crate) use kv::Kv;
 use serde::de::DeserializeOwned;
-
-pub(crate) use kv::Kv;
-
-use crate::StorageProxyConfig;
+use std::fmt::Debug;
+use worker::Env;
 
 #[derive(Debug, thiserror::Error)]
 pub(crate) enum Error {
     #[error("serialization error: {0}")]
     Serde(#[from] serde_json::Error),
-    #[error("network error: {0}")]
-    Reqwest(#[from] reqwest::Error),
+    #[error("worker error: {0}")]
+    Worker(#[from] worker::Error),
     #[error("http error. request returned status code {status} with the body {body}")]
     Http { status: StatusCode, body: String },
 }
 
 #[derive(Clone, Copy)]
 pub(crate) struct Do<'h> {
-    config: &'h StorageProxyConfig,
-    http: &'h reqwest::Client,
+    env: &'h Env,
     retry: bool,
 }
 
 impl<'h> Do<'h> {
-    pub fn new(config: &'h StorageProxyConfig, client: &'h reqwest::Client) -> Self {
-        Self {
-            config,
-            http: client,
-            retry: false,
-        }
+    pub fn new(env: &'h Env) -> Self {
+        Self { env, retry: false }
     }
 
     #[expect(dead_code)]
@@ -68,27 +61,24 @@
             path = ?self.path,
             "requesting DO",
         );
-        let url = self
-            .durable
-            .config
-            .url
-            .join(&format!("{DO_PATH_PREFIX}{}", self.path.to_uri()))
-            .unwrap();
-        let resp = self
-            .durable
-            .http
-            .post(url)
-            .body(self.request.into_bytes())
-            .bearer_auth(self.durable.config.auth_token.as_str())
-            .send()
-            .await?;
+        let resp = storage_proxy::handle_do_request(
+            self.durable.env,
+            Default::default(),
+            self.path.to_uri(),
+            self.request,
+            |_, _, _| {},
+        )
+        .await?;
 
-        if resp.status().is_success() {
-            Ok(resp.json().await?)
+        use http_body_util::BodyExt;
+        let (resp, body) = resp.into_parts();
+        let body = body.collect().await?.to_bytes();
+        if resp.status.is_success() {
+            Ok(serde_json::from_slice(&body)?)
         } else {
             Err(Error::Http {
-                status: resp.status(),
-                body: resp.text().await?,
+                status: resp.status,
+                body: String::from_utf8_lossy(&body).into_owned(),
             })
         }
     }
Differences in crates/daphne-worker/src/storage/kv/mod.rs
--- crates/daphne-server/src/storage_proxy_connection/kv/mod.rs
+++ crates/daphne-worker/src/storage/kv/mod.rs
@@ -2,45 +2,37 @@
 // SPDX-License-Identifier: BSD-3-Clause
 
 mod cache;
-mod request_coalescer;
-
-use std::{any::Any, fmt::Display, future::Future};
-
-use axum::http::StatusCode;
+// mod request_coalescer;
+
+use std::{any::Any, fmt::Display, future::Future, sync::RwLock};
+
 use daphne_service_utils::durable_requests::KV_PATH_PREFIX;
 use mappable_rc::Marc;
 use serde::{de::DeserializeOwned, Serialize};
-use tokio::sync::RwLock;
 use tracing::{info_span, Instrument};
 
-use crate::StorageProxyConfig;
-
 use super::Error;
+use crate::storage_proxy;
 use cache::Cache;
 use daphne::messages::Time;
-use daphne_service_utils::http_headers::STORAGE_PROXY_PUT_KV_EXPIRATION;
+use worker::send::SendWrapper;
 
 #[derive(Default)]
 pub struct State {
     cache: RwLock<Cache>,
-    coalescer: request_coalescer::RequestCoalescer,
 }
 
 impl State {
     #[cfg(feature = "test-utils")]
-    pub async fn reset(&self) {
-        let Self { cache, coalescer } = self;
-
-        let clear_cache = async {
-            *cache.write().await = Default::default();
-        };
-        tokio::join!(clear_cache, coalescer.reset());
+    pub fn reset(&self) {
+        let Self { cache } = self;
+
+        *cache.write().unwrap() = Default::default();
     }
 }
 
 pub(crate) struct Kv<'h> {
-    config: &'h StorageProxyConfig,
-    http: &'h reqwest::Client,
+    env: &'h SendWrapper<worker::Env>,
     state: &'h State,
 }
 
@@ -59,15 +51,16 @@
 
     use daphne::{
         constants::DapRole,
+        hpke::HpkeReceiverConfig,
         messages::{Base64Encode, TaskId},
         taskprov, DapTaskConfig, DapVersion,
     };
     use daphne_service_utils::bearer_token::BearerToken;
     use serde::{de::DeserializeOwned, Serialize};
 
-    use crate::config::HpkeRecieverConfigList;
-
     use super::KvPrefix;
+
+    pub type HpkeRecieverConfigList = Vec<HpkeReceiverConfig>;
 
     #[derive(Debug)]
     pub struct GlobalConfigOverride<V>(PhantomData<V>);
@@ -181,23 +174,15 @@
 }
 
 impl<'h> Kv<'h> {
-    pub fn new(
-        config: &'h StorageProxyConfig,
-        client: &'h reqwest::Client,
-        state: &'h State,
-    ) -> Self {
-        Self {
-            config,
-            http: client,
-            state,
-        }
+    pub fn new(env: &'h SendWrapper<worker::Env>, state: &'h State) -> Self {
+        Self { env, state }
     }
 
     pub async fn get<P>(
         &self,
         key: &P::Key,
         opt: &KvGetOptions,
-    ) -> Result<Option<Marc<P::Value>>, Marc<Error>>
+    ) -> Result<Option<Marc<P::Value>>, Error>
     where
         P: KvPrefix,
         P::Key: std::fmt::Debug,
@@ -209,7 +194,7 @@
         &self,
         key: &P::Key,
         opt: &KvGetOptions,
-    ) -> Result<Option<P::Value>, Marc<Error>>
+    ) -> Result<Option<P::Value>, Error>
     where
         P: KvPrefix,
         P::Key: std::fmt::Debug,
@@ -223,15 +208,16 @@
         key: &P::Key,
         opt: &KvGetOptions,
         mapper: F,
-    ) -> Result<Option<Marc<R>>, Marc<Error>>
+    ) -> Result<Option<Marc<R>>, Error>
     where
         P: KvPrefix,
         P::Key: std::fmt::Debug,
         F: for<'s> FnOnce(&'s P::Value) -> Option<&'s R>,
         R: Send + Sync + 'static,
     {
-        self.get_coalesced::<P, _, _>(key, opt, |marc| Marc::try_map(marc, mapper).ok())
+        self.get_internal::<P, _, _>(key, opt, Some)
             .await
+            .map(|opt| opt.flatten().map(|marc| Marc::try_map(marc, mapper).ok()))
             .map(Option::flatten)
     }
 
@@ -241,62 +227,35 @@
         opt: &KvGetOptions,
         default: impl FnOnce() -> Fut,
         expiration: Option<Time>,
-    ) -> Result<Marc<P::Value>, Marc<GetOrInsertError<E>>>
+    ) -> Result<Marc<P::Value>, GetOrInsertError<E>>
     where
         P: KvPrefix,
         P::Key: std::fmt::Debug,
         E: Send + Sync + 'static,
         Fut: Future<Output = Result<P::Value, E>>,
     {
-        self.state
-            .coalescer
-            .coalesce(Self::to_key::<P>(key), || async {
-                if let Some(v) = self.get_internal::<P, _, _>(key, opt, |marc| marc).await? {
-                    return Ok(Some(v));
-                }
-                let default = default().await.map_err(GetOrInsertError::Other)?;
-                let cached = self.put_internal::<P>(key, default, expiration).await?;
-                Ok(Some(cached))
-            })
+        if let Some(v) = self.get_internal::<P, _, _>(key, opt, |marc| marc).await? {
+            return Ok(v);
+        }
+        let default = default().await.map_err(GetOrInsertError::Other)?;
+        let cached = self.put_internal::<P>(key, default, expiration).await?;
+        Ok(cached)
+    }
+
+    pub async fn peek<P, R, F>(
+        &self,
+        key: &P::Key,
+        opt: &KvGetOptions,
+        peeker: F,
+    ) -> Result<Option<R>, Error>
+    where
+        P: KvPrefix,
+        P::Key: std::fmt::Debug,
+        F: FnOnce(&P::Value) -> R,
+    {
+        self.get_internal::<P, _, _>(key, opt, Some)
             .await
-            .map(|v| v.unwrap()) // all paths of the previous closure return Some
-    }
-
-    pub async fn peek<P, R, F>(
-        &self,
-        key: &P::Key,
-        opt: &KvGetOptions,
-        peeker: F,
-    ) -> Result<Option<R>, Marc<Error>>
-    where
-        P: KvPrefix,
-        P::Key: std::fmt::Debug,
-        F: FnOnce(&P::Value) -> R,
-    {
-        self.get_coalesced::<P, _, _>(key, opt, |marc| peeker(&marc))
-            .await
-    }
-
-    async fn get_coalesced<P, R, F>(
-        &self,
-        key: &P::Key,
-        opt: &KvGetOptions,
-        mapper: F,
-    ) -> Result<Option<R>, Marc<Error>>
-    where
-        P: KvPrefix,
-        P::Key: std::fmt::Debug,
-        F: FnOnce(Marc<P::Value>) -> R,
-    {
-        self.state
-            .coalescer
-            .coalesce(Self::to_key::<P>(key), || async {
-                self.get_internal::<P, _, _>(key, opt, Some)
-                    .await
-                    .map(Option::flatten)
-            })
-            .await
-            .map(|opt_v| opt_v.map(mapper))
+            .map(|opt| opt.flatten().map(|marc| peeker(&marc)))
     }
 
     async fn get_internal<P, R, F>(
@@ -312,7 +271,7 @@
     {
         let key = Self::to_key::<P>(key);
         tracing::debug!(key, "GET");
-        match self.state.cache.read().await.get::<P>(&key) {
+        match self.state.cache.read().unwrap().get::<P>(&key) {
             cache::CacheResult::Miss => {}
             cache::CacheResult::Hit(t) => return Ok(t.map(mapper)),
             cache::CacheResult::MismatchedType => {
@@ -329,34 +288,22 @@
             prefix = std::any::type_name::<P>()
         );
         async {
-            let resp = self
-                .http
-                .get(self.config.url.join(&key).unwrap())
-                .bearer_auth(self.config.auth_token.as_str())
-                .send()
-                .await?;
-            if resp.status() == StatusCode::NOT_FOUND {
+            if let Some(v) = storage_proxy::kv_get(self.env, &key).await? {
+                let t = Marc::new(serde_json::from_slice::<P::Value>(&v)?);
+                let r = mapper(t.clone());
+                self.state.cache.write().unwrap().put::<P>(key, Some(t));
+                Ok(Some(r))
+            } else {
                 if opt.cache_not_found {
-                    self.state.cache.write().await.put::<P>(key, None);
+                    self.state.cache.write().unwrap().put::<P>(key, None);
                 }
                 Ok(None)
-            } else {
-                let resp = resp.error_for_status()?;
-                let t = Marc::new(resp.json::<P::Value>().await?);
-                let r = mapper(t.clone());
-                self.state.cache.write().await.put::<P>(key, Some(t));
-                Ok(Some(r))
             }
         }
         .instrument(span)
         .await
     }
 
-    #[tracing::instrument(
-        name = "kv_put",
-        skip_all,
-        fields(key, prefix = std::any::type_name::<P>()),
-    )]
     pub async fn put_internal<P>(
         &self,
         key: &P::Key,
@@ -371,23 +318,19 @@
         let key = Self::to_key::<P>(key);
         tracing::debug!(key, "PUT");
 
-        let mut request = self
-            .http
-            .post(self.config.url.join(&key).unwrap())
-            .bearer_auth(self.config.auth_token.as_str())
-            .body(serde_json::to_vec(&value).unwrap());
-
-        if let Some(expiration) = expiration {
-            request = request.header(STORAGE_PROXY_PUT_KV_EXPIRATION, expiration);
-        }
-
-        request.send().await?.error_for_status()?;
+        storage_proxy::kv_put(
+            self.env,
+            expiration,
+            &key,
+            &serde_json::to_vec(&value).unwrap(),
+        )
+        .await?;
 
         let value = Marc::new(value);
         self.state
             .cache
             .write()
-            .await
+            .unwrap()
             .put::<P>(key, Some(value.clone()));
         Ok(value)
     }
@@ -406,6 +349,7 @@
         self.put_internal::<P>(key, value, Some(expiration)).await
     }
 
+    #[cfg_attr(not(feature = "test-utils"), expect(dead_code))]
     pub async fn put<P>(&self, key: &P::Key, value: P::Value) -> Result<Marc<P::Value>, Error>
     where
         P: KvPrefix,
@@ -418,11 +362,6 @@
     /// Stores a value in kv if it doesn't already exist.
     ///
     /// If the value already exists, returns the passed in value inside the Ok variant.
-    #[tracing::instrument(
-        name = "kv_put_if_not_exists",
-        skip_all,
-        fields(key, prefix = std::any::type_name::<P>()),
-    )]
     pub async fn put_if_not_exists_internal<P>(
         &self,
         key: &P::Key,
@@ -437,31 +376,26 @@
 
         tracing::debug!(key, "PUT if not exists");
 
-        let mut request = self
-            .http
-            .put(self.config.url.join(&key).unwrap())
-            .bearer_auth(self.config.auth_token.as_str())
-            .body(serde_json::to_vec(&value).unwrap());
-
-        if let Some(expiration) = expiration {
-            request = request.header(STORAGE_PROXY_PUT_KV_EXPIRATION, expiration);
-        }
-
-        let response = request.send().await?;
-
-        if response.status() == StatusCode::CONFLICT {
-            Ok(Some(value))
-        } else {
-            response.error_for_status()?;
+        let inserted = storage_proxy::kv_put_if_not_exists(
+            self.env,
+            expiration,
+            &key,
+            &serde_json::to_vec(&value).unwrap(),
+        )
+        .await?;
+        if inserted {
             self.state
                 .cache
                 .write()
-                .await
+                .unwrap()
                 .put::<P>(key, Some(value.into()));
             Ok(None)
-        }
-    }
-
+        } else {
+            Ok(Some(value))
+        }
+    }
+
+    #[cfg_attr(not(feature = "test-utils"), expect(dead_code))]
     pub async fn put_if_not_exists_with_expiration<P>(
         &self,
         key: &P::Key,
@@ -477,6 +411,7 @@
             .await
     }
 
+    #[cfg_attr(not(feature = "test-utils"), expect(dead_code))]
     pub async fn put_if_not_exists<P>(
         &self,
         key: &P::Key,
@@ -500,7 +435,7 @@
         self.state
             .cache
             .write()
-            .await
+            .unwrap()
             .put::<P>(key, Some(value.into()));
     }
 
Differences in crates/daphne-worker/src/storage/kv/cache.rs
--- crates/daphne-server/src/storage_proxy_connection/kv/cache.rs
+++ crates/daphne-worker/src/storage/kv/cache.rs
@@ -1,21 +1,19 @@
 // Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
-use std::{
-    any::Any,
-    collections::HashMap,
-    time::{Duration, Instant},
-};
+use std::{any::Any, collections::HashMap, time::Duration};
 
 use mappable_rc::Marc;
 
 use super::KvPrefix;
+use crate::elapsed;
+use worker::send::SendWrapper;
 
 const CACHE_VALUE_LIFETIME: Duration = Duration::from_secs(60 * 5);
 
 struct CacheLine {
     /// Time at which the cache item was set.
-    ts: Instant,
+    ts: SendWrapper<worker::Date>,
 
     /// Either the value or an indication that no value was found.
     entry: Option<Marc<dyn Any + Send + Sync + 'static>>,
@@ -48,7 +46,7 @@
         match self.kv.get(P::PREFIX) {
             Some(cache) => match cache.get(key) {
                 // Cache hit
-                Some(CacheLine { ts, entry }) if ts.elapsed() < CACHE_VALUE_LIFETIME => entry
+                Some(CacheLine { ts, entry }) if elapsed(ts) < CACHE_VALUE_LIFETIME => entry
                     .as_ref()
                     .map(|entry| Marc::try_map(entry.clone(), |v| v.downcast_ref::<P::Value>()))
                     .transpose() // bring out the try_map error
@@ -70,7 +68,7 @@
         self.kv.entry(P::PREFIX).or_default().insert(
             key,
             CacheLine {
-                ts: Instant::now(),
+                ts: SendWrapper(worker::Date::now()),
                 entry: entry.map(|value| Marc::map(value, |v| v as &(dyn Any + Send + Sync))),
             },
         );

NOTE: CI fails because I included a test SIGNING_KEY for the aggregator worker.

@mendess mendess self-assigned this Jan 6, 2025
@mendess mendess force-pushed the mendess/async-helper branch 6 times, most recently from 8245283 to cf9707f Compare January 6, 2025 16:53
@mendess mendess marked this pull request as ready for review January 6, 2025 16:53
Copy link
Contributor

@cjpatton cjpatton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comments only. Thanks for the super thorough description!

@mendess mendess force-pushed the mendess/async-helper branch from cf9707f to 61bb8f7 Compare January 8, 2025 10:56
@mendess mendess force-pushed the mendess/async-helper branch from 61bb8f7 to 69adc1d Compare January 8, 2025 10:57
@mendess mendess force-pushed the mendess/async-helper branch from 69adc1d to 859e80e Compare January 8, 2025 11:14
@mendess mendess merged commit 964bb98 into main Jan 8, 2025
3 of 4 checks passed
@mendess mendess deleted the mendess/async-helper branch January 8, 2025 13:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants