Skip to content

Commit 2894ff3

Browse files
committed
Hash the encoded aggregation request instead of hashing the decoded one
1 parent 0447fd6 commit 2894ff3

File tree

14 files changed

+181
-116
lines changed

14 files changed

+181
-116
lines changed

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/daphne-server/src/roles/helper.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
use axum::async_trait;
55
use daphne::{
6-
messages::{AggregationJobId, AggregationJobInitReq, TaskId},
6+
messages::{request::AggregationJobRequestHash, AggregationJobId, TaskId},
77
roles::DapHelper,
88
DapError, DapVersion,
99
};
@@ -15,7 +15,7 @@ impl DapHelper for crate::App {
1515
_id: AggregationJobId,
1616
_version: DapVersion,
1717
_task_id: &TaskId,
18-
_req: &AggregationJobInitReq,
18+
_req: &AggregationJobRequestHash,
1919
) -> Result<(), DapError> {
2020
// the server implementation can't check for this
2121
Ok(())

crates/daphne-server/src/router/extractor.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use daphne::{
1313
error::DapAbort,
1414
fatal_error,
1515
messages::{
16-
request::{CollectionPollReq, RequestBody},
16+
request::{CollectionPollReq, HashedAggregationJobReq, RequestBody},
1717
taskprov::TaskprovAdvertisement,
1818
AggregateShareReq, AggregationJobInitReq, CollectionReq, Report, TaskId,
1919
},
@@ -60,6 +60,17 @@ impl_decode_from_dap_http_body!(
6060
CollectionReq,
6161
);
6262

63+
impl DecodeFromDapHttpBody for HashedAggregationJobReq {
64+
fn decode_from_http_body(bytes: Bytes, meta: &DapRequestMeta) -> Result<Self, DapAbort> {
65+
let mut cursor = Cursor::new(bytes.as_ref());
66+
// Check that media type matches.
67+
meta.get_checked_media_type(DapMediaType::AggregationJobInitReq)?;
68+
// Decode the body
69+
HashedAggregationJobReq::decode_with_param(&meta.version, &mut cursor)
70+
.map_err(|e| DapAbort::from_codec_error(e, meta.task_id))
71+
}
72+
}
73+
6374
/// Using `()` ignores the body of a request.
6475
impl DecodeFromDapHttpBody for CollectionPollReq {
6576
fn decode_from_http_body(_bytes: Bytes, _meta: &DapRequestMeta) -> Result<Self, DapAbort> {

crates/daphne-server/src/router/helper.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use axum::{
88
routing::{post, put},
99
};
1010
use daphne::{
11-
messages::{AggregateShareReq, AggregationJobInitReq},
11+
messages::{request::HashedAggregationJobReq, AggregateShareReq},
1212
roles::{helper, DapHelper},
1313
};
1414
use http::StatusCode;
@@ -38,7 +38,7 @@ pub(super) fn add_helper_routes(router: super::Router<App>) -> super::Router<App
3838
)]
3939
async fn agg_job(
4040
State(app): State<Arc<App>>,
41-
DapRequestExtractor(req): DapRequestExtractor<FROM_LEADER, AggregationJobInitReq>,
41+
DapRequestExtractor(req): DapRequestExtractor<FROM_LEADER, HashedAggregationJobReq>,
4242
) -> AxumDapResponse {
4343
let timer = std::time::Instant::now();
4444

crates/daphne-service-utils/Cargo.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ prio_draft09 = { workspace = true, optional = true }
1818
prio = { workspace = true, optional = true }
1919
serde.workspace = true
2020
url = { workspace = true, optional = true }
21-
ring = { workspace = true, optional = true }
2221

2322
[dev-dependencies]
2423
daphne = { path = "../daphne", default-features = false, features = ["prometheus"] }
@@ -34,7 +33,6 @@ durable_requests = [
3433
"dep:capnpc",
3534
"dep:prio_draft09",
3635
"dep:prio",
37-
"dep:ring"
3836
]
3937
experimental = ["daphne/experimental"]
4038

crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs

Lines changed: 8 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ use crate::{
77
durable_requests::ObjectIdFrom,
88
};
99
use daphne::{
10-
messages::{AggregationJobId, AggregationJobInitReq, PartialBatchSelector, TaskId},
10+
messages::{AggregationJobId, TaskId},
1111
DapVersion,
1212
};
1313
use serde::{Deserialize, Serialize};
14-
use std::{ops::Deref, slice};
14+
use std::borrow::Cow;
1515

1616
super::define_do_binding! {
1717
const BINDING = "AGGREGATION_JOB_STORE";
@@ -27,73 +27,21 @@ super::define_do_binding! {
2727
}
2828

2929
#[derive(Debug)]
30-
pub struct AggregationJobReqHash(Vec<u8>);
31-
32-
impl Deref for AggregationJobReqHash {
33-
type Target = [u8];
34-
fn deref(&self) -> &Self::Target {
35-
&self.0
36-
}
37-
}
38-
39-
impl From<&AggregationJobInitReq> for AggregationJobReqHash {
40-
fn from(req: &AggregationJobInitReq) -> Self {
41-
let AggregationJobInitReq {
42-
agg_param,
43-
part_batch_sel,
44-
prep_inits,
45-
} = req;
46-
47-
let mut context = ring::digest::Context::new(&ring::digest::SHA256);
48-
context.update(agg_param);
49-
context.update(match part_batch_sel {
50-
PartialBatchSelector::TimeInterval => &[0],
51-
PartialBatchSelector::LeaderSelectedByBatchId { batch_id } => batch_id.as_ref(),
52-
});
53-
for p in prep_inits {
54-
let daphne::messages::PrepareInit {
55-
report_share:
56-
daphne::messages::ReportShare {
57-
report_metadata: daphne::messages::ReportMetadata { id, time },
58-
public_share,
59-
encrypted_input_share:
60-
daphne::messages::HpkeCiphertext {
61-
config_id,
62-
enc,
63-
payload: cypher_text_payload,
64-
},
65-
},
66-
payload,
67-
} = p;
68-
69-
context.update(payload);
70-
context.update(public_share);
71-
context.update(id.as_ref());
72-
context.update(&time.to_be_bytes());
73-
context.update(cypher_text_payload);
74-
context.update(slice::from_ref(config_id));
75-
context.update(enc);
76-
}
77-
Self(context.finish().as_ref().to_vec())
78-
}
79-
}
80-
81-
#[derive(Debug)]
82-
pub struct NewJobRequest {
30+
pub struct NewJobRequest<'h> {
8331
pub id: AggregationJobId,
84-
pub agg_job_hash: AggregationJobReqHash,
32+
pub agg_job_hash: Cow<'h, [u8]>,
8533
}
8634

87-
impl CapnprotoPayloadEncode for NewJobRequest {
35+
impl CapnprotoPayloadEncode for NewJobRequest<'_> {
8836
type Builder<'a> = new_job_request::Builder<'a>;
8937

9038
fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
9139
self.id.encode_to_builder(builder.reborrow().init_id());
92-
builder.set_agg_job_hash(&self.agg_job_hash.0);
40+
builder.set_agg_job_hash(&self.agg_job_hash);
9341
}
9442
}
9543

96-
impl CapnprotoPayloadDecode for NewJobRequest {
44+
impl CapnprotoPayloadDecode for NewJobRequest<'static> {
9745
type Reader<'a> = new_job_request::Reader<'a>;
9846

9947
fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
@@ -102,7 +50,7 @@ impl CapnprotoPayloadDecode for NewJobRequest {
10250
{
10351
Ok(Self {
10452
id: <_>::decode_from_reader(reader.get_id()?)?,
105-
agg_job_hash: AggregationJobReqHash(reader.get_agg_job_hash()?.to_vec()),
53+
agg_job_hash: reader.get_agg_job_hash()?.to_vec().into(),
10654
})
10755
}
10856
}

crates/daphne-worker/src/aggregator/roles/helper.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@ use crate::aggregator::App;
55
use daphne::{
66
error::DapAbort,
77
fatal_error,
8-
messages::{AggregationJobId, AggregationJobInitReq, TaskId},
8+
messages::{request::AggregationJobRequestHash, AggregationJobId, TaskId},
99
roles::DapHelper,
1010
DapError, DapVersion,
1111
};
1212
use daphne_service_utils::durable_requests::bindings::aggregation_job_store;
13+
use std::borrow::Cow;
1314

1415
#[axum::async_trait]
1516
impl DapHelper for App {
@@ -18,15 +19,15 @@ impl DapHelper for App {
1819
id: AggregationJobId,
1920
version: DapVersion,
2021
task_id: &TaskId,
21-
req: &AggregationJobInitReq,
22+
req: &AggregationJobRequestHash,
2223
) -> Result<(), DapError> {
2324
let response = self
2425
.durable()
2526
.with_retry()
2627
.request(aggregation_job_store::Command::NewJob, (version, task_id))
2728
.encode(&aggregation_job_store::NewJobRequest {
2829
id,
29-
agg_job_hash: req.into(),
30+
agg_job_hash: Cow::Borrowed(req.get()),
3031
})
3132
.send::<aggregation_job_store::NewJobResponse>()
3233
.await

crates/daphne-worker/src/aggregator/router/extractor.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use daphne::{
1313
error::DapAbort,
1414
fatal_error,
1515
messages::{
16-
request::{CollectionPollReq, RequestBody},
16+
request::{CollectionPollReq, HashedAggregationJobReq, RequestBody},
1717
taskprov::TaskprovAdvertisement,
1818
AggregateShareReq, AggregationJobInitReq, CollectionReq, Report, TaskId,
1919
},
@@ -60,6 +60,17 @@ impl_decode_from_dap_http_body!(
6060
CollectionReq,
6161
);
6262

63+
impl DecodeFromDapHttpBody for HashedAggregationJobReq {
64+
fn decode_from_http_body(bytes: Bytes, meta: &DapRequestMeta) -> Result<Self, DapAbort> {
65+
let mut cursor = Cursor::new(bytes.as_ref());
66+
// Check that media type matches.
67+
meta.get_checked_media_type(DapMediaType::AggregationJobInitReq)?;
68+
// Decode the body
69+
HashedAggregationJobReq::decode_with_param(&meta.version, &mut cursor)
70+
.map_err(|e| DapAbort::from_codec_error(e, meta.task_id))
71+
}
72+
}
73+
6374
/// Using `()` ignores the body of a request.
6475
impl DecodeFromDapHttpBody for CollectionPollReq {
6576
fn decode_from_http_body(_bytes: Bytes, _meta: &DapRequestMeta) -> Result<Self, DapAbort> {

crates/daphne-worker/src/aggregator/router/helper.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use axum::{
88
routing::{post, put},
99
};
1010
use daphne::{
11-
messages::{AggregateShareReq, AggregationJobInitReq},
11+
messages::{request::HashedAggregationJobReq, AggregateShareReq},
1212
roles::{helper, DapHelper},
1313
};
1414
use http::StatusCode;
@@ -39,7 +39,7 @@ pub(super) fn add_helper_routes(router: super::Router<App>) -> super::Router<App
3939
#[worker::send]
4040
async fn agg_job(
4141
State(app): State<Arc<App>>,
42-
DapRequestExtractor(req): DapRequestExtractor<FROM_LEADER, AggregationJobInitReq>,
42+
DapRequestExtractor(req): DapRequestExtractor<FROM_LEADER, HashedAggregationJobReq>,
4343
) -> AxumDapResponse {
4444
let now = worker::Date::now();
4545

crates/daphne/src/messages/request.rs

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use super::{
88
CollectionJobId, CollectionReq, Report,
99
};
1010
use crate::{constants::DapMediaType, error::DapAbort, messages::TaskId, DapVersion};
11+
use prio::codec::{ParameterizedDecode, ParameterizedEncode};
1112

1213
pub trait RequestBody {
1314
type ResourceId;
@@ -24,15 +25,79 @@ macro_rules! impl_req_body {
2425
};
2526
}
2627

28+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
29+
#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))]
30+
pub struct AggregationJobRequestHash(Vec<u8>);
31+
32+
impl AggregationJobRequestHash {
33+
pub fn get(&self) -> &[u8] {
34+
&self.0
35+
}
36+
37+
fn hash(bytes: &[u8]) -> Self {
38+
Self(
39+
ring::digest::digest(&ring::digest::SHA256, bytes)
40+
.as_ref()
41+
.to_vec(),
42+
)
43+
}
44+
}
45+
46+
pub struct HashedAggregationJobReq {
47+
pub request: AggregationJobInitReq,
48+
pub hash: AggregationJobRequestHash,
49+
}
50+
51+
impl HashedAggregationJobReq {
52+
#[cfg(any(test, feature = "test-utils"))]
53+
pub fn from_aggregation_req(version: DapVersion, request: AggregationJobInitReq) -> Self {
54+
let mut buf = Vec::new();
55+
request.encode_with_param(&version, &mut buf).unwrap();
56+
Self {
57+
request,
58+
hash: AggregationJobRequestHash::hash(&buf),
59+
}
60+
}
61+
}
62+
63+
impl ParameterizedEncode<DapVersion> for HashedAggregationJobReq {
64+
fn encode_with_param(
65+
&self,
66+
encoding_parameter: &DapVersion,
67+
bytes: &mut Vec<u8>,
68+
) -> Result<(), prio::codec::CodecError> {
69+
self.request.encode_with_param(encoding_parameter, bytes)
70+
}
71+
}
72+
73+
impl ParameterizedDecode<DapVersion> for HashedAggregationJobReq {
74+
fn decode_with_param(
75+
decoding_parameter: &DapVersion,
76+
bytes: &mut std::io::Cursor<&[u8]>,
77+
) -> Result<Self, prio::codec::CodecError> {
78+
let start = usize::try_from(bytes.position())
79+
.map_err(|e| prio::codec::CodecError::Other(Box::new(e)))?;
80+
let request = AggregationJobInitReq::decode_with_param(decoding_parameter, bytes)?;
81+
let end = usize::try_from(bytes.position())
82+
.map_err(|e| prio::codec::CodecError::Other(Box::new(e)))?;
83+
84+
Ok(Self {
85+
request,
86+
hash: AggregationJobRequestHash::hash(&bytes.get_ref()[start..end]),
87+
})
88+
}
89+
}
90+
2791
impl_req_body! {
28-
// body type | id type
29-
// --------------------- | ----------------
30-
Report | ()
31-
AggregationJobInitReq | AggregationJobId
32-
AggregateShareReq | ()
33-
CollectionReq | CollectionJobId
34-
CollectionPollReq | CollectionJobId
35-
() | ()
92+
// body type | id type
93+
// --------------------| ----------------
94+
Report | ()
95+
AggregationJobInitReq | AggregationJobId
96+
HashedAggregationJobReq | AggregationJobId
97+
AggregateShareReq | ()
98+
CollectionReq | CollectionJobId
99+
CollectionPollReq | CollectionJobId
100+
() | ()
36101
}
37102

38103
/// Fields common to all DAP requests.
@@ -74,6 +139,20 @@ pub struct DapRequest<B: RequestBody> {
74139
pub payload: B,
75140
}
76141

142+
impl<B: RequestBody> DapRequest<B> {
143+
pub fn map<F, O>(self, mapper: F) -> DapRequest<O>
144+
where
145+
F: FnOnce(B) -> O,
146+
O: RequestBody<ResourceId = B::ResourceId>,
147+
{
148+
DapRequest {
149+
meta: self.meta,
150+
resource_id: self.resource_id,
151+
payload: mapper(self.payload),
152+
}
153+
}
154+
}
155+
77156
impl<B: RequestBody> AsRef<DapRequestMeta> for DapRequest<B> {
78157
fn as_ref(&self) -> &DapRequestMeta {
79158
&self.meta

0 commit comments

Comments
 (0)