Skip to content

Commit ed04ead

Browse files
committed
SQUASH Add public extensions to capnp schema
1 parent 8815e30 commit ed04ead

File tree

3 files changed

+108
-21
lines changed

3 files changed

+108
-21
lines changed

crates/daphne-service-utils/src/compute_offload/compute_offload.capnp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,24 @@ struct PartialDapTaskConfig @0xdcc9bf18fc62d406 {
5353
vdafVerifyKey @4 :VdafVerifyKey;
5454
}
5555

56+
struct PublicExtensionsList @0x8b3c98c0ddd0043e {
57+
58+
union {
59+
# Each extension is encoded according to the DAP spec in
60+
# tag-length-value form.
61+
list @0 :List(Data);
62+
63+
# draft09 compatibility: Previously DAP had no extensions in the
64+
# report.
65+
none @1 :Void;
66+
}
67+
}
68+
5669
struct ReportMetadata @0xefba178ad4584bc4 {
5770

58-
id @0 :Base.ReportId;
59-
time @1 :Base.Time;
71+
id @0 :Base.ReportId;
72+
time @1 :Base.Time;
73+
publicExtensions @2 :PublicExtensionsList;
6074
}
6175

6276
struct PrepareInit @0x8192568cb3d03f59 {

crates/daphne-service-utils/src/compute_offload/mod.rs

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@ use crate::{
1010
hpke_receiver_config::{self, hpke_config},
1111
initialize_reports,
1212
initialized_reports::{self, initialized_report},
13-
partial_dap_task_config, prepare_init, report_metadata, time_range,
13+
partial_dap_task_config, prepare_init, public_extensions_list, report_metadata, time_range,
1414
},
1515
};
1616
use daphne::{
1717
constants::DapAggregatorRole,
1818
hpke::{HpkeConfig, HpkeReceiverConfig},
19-
messages::{self, HpkeCiphertext, PrepareInit, ReportMetadata, ReportShare, TaskId},
19+
messages::{self, Extension, HpkeCiphertext, PrepareInit, ReportMetadata, ReportShare, TaskId},
2020
vdaf::{VdafConfig, VdafPrepShare, VdafPrepState},
2121
InitializedReport, PartialDapTaskConfigForReportInit, WithPeerPrepShare,
2222
};
23-
use prio::codec::{Encode, ParameterizedDecode, ParameterizedEncode};
23+
use prio::codec::{Decode, Encode, ParameterizedDecode, ParameterizedEncode};
2424
use std::{borrow::Cow, ops::Range};
2525

2626
pub struct InitializeReports<'s> {
@@ -318,9 +318,27 @@ impl CapnprotoPayloadEncode for ReportMetadata {
318318
type Builder<'a> = report_metadata::Builder<'a>;
319319

320320
fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
321-
let Self { id, time } = self;
321+
let Self {
322+
id,
323+
time,
324+
public_extensions,
325+
} = self;
322326
id.encode_to_builder(builder.reborrow().init_id());
323327
builder.set_time(*time);
328+
if let Some(ref extensions) = public_extensions {
329+
let mut e = builder
330+
.init_public_extensions()
331+
.init_list(usize_to_capnp_len(extensions.len()));
332+
for (i, data) in extensions
333+
.iter()
334+
.enumerate()
335+
.map(|(i, ext)| (usize_to_capnp_len(i), ext.get_encoded().unwrap()))
336+
{
337+
e.reborrow().set(i, &data);
338+
}
339+
} else {
340+
builder.init_public_extensions().set_none(());
341+
}
324342
}
325343
}
326344

@@ -331,9 +349,25 @@ impl CapnprotoPayloadDecode for ReportMetadata {
331349
where
332350
Self: Sized,
333351
{
352+
let id = <_>::decode_from_reader(reader.get_id()?)?;
353+
let time = reader.get_time();
354+
let public_extensions = match reader.get_public_extensions()?.which()? {
355+
public_extensions_list::List(list) => Some(
356+
list?
357+
.into_iter()
358+
.map(|data| {
359+
Extension::get_decoded(data?)
360+
.map_err(|e| capnp::Error::failed(e.to_string()))
361+
})
362+
.collect::<Result<Vec<_>, capnp::Error>>()?,
363+
),
364+
public_extensions_list::None(()) => None,
365+
};
366+
334367
Ok(Self {
335-
id: <_>::decode_from_reader(reader.get_id()?)?,
336-
time: reader.get_time(),
368+
id,
369+
time,
370+
public_extensions,
337371
})
338372
}
339373
}
@@ -486,3 +520,45 @@ fn to_capnp<E: ToString>(e: E) -> capnp::Error {
486520
extra: e.to_string(),
487521
}
488522
}
523+
524+
#[cfg(test)]
525+
mod test {
526+
use super::*;
527+
use crate::capnproto::{CapnprotoPayloadDecodeExt, CapnprotoPayloadEncodeExt};
528+
529+
#[test]
530+
fn report_metadata_roundtrip() {
531+
let report_metadata = ReportMetadata {
532+
id: messages::ReportId(rand::random()),
533+
time: rand::random(),
534+
public_extensions: Some(vec![
535+
Extension::Taskprov,
536+
Extension::NotImplemented {
537+
typ: 23,
538+
payload: b"some extension payload".to_vec(),
539+
},
540+
]),
541+
};
542+
543+
assert_eq!(
544+
report_metadata,
545+
ReportMetadata::decode_from_bytes(&report_metadata.encode_to_bytes()).unwrap()
546+
);
547+
}
548+
549+
#[test]
550+
fn report_metadata_roundtrip_draft09() {
551+
let report_metadata = ReportMetadata {
552+
id: messages::ReportId(rand::random()),
553+
time: rand::random(),
554+
// draft09 compatibility: Previously there was no extensions field in the report
555+
// metadata.
556+
public_extensions: None,
557+
};
558+
559+
assert_eq!(
560+
report_metadata,
561+
ReportMetadata::decode_from_bytes(&report_metadata.encode_to_bytes()).unwrap()
562+
);
563+
}
564+
}

crates/daphne/src/messages/mod.rs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,13 @@ impl Extension {
195195
}
196196
}
197197

198-
impl ParameterizedEncode<DapVersion> for Extension {
199-
fn encode_with_param(
200-
&self,
201-
version: &DapVersion,
202-
bytes: &mut Vec<u8>,
203-
) -> Result<(), CodecError> {
198+
impl Encode for Extension {
199+
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
204200
match self {
205201
Self::Taskprov => {
206202
EXTENSION_TASKPROV.encode(bytes)?;
207-
encode_u16_prefixed(*version, bytes, |_, _| Ok(()))?;
203+
// We've hard coded the version here, but we don't actually use it.
204+
encode_u16_prefixed(DapVersion::Latest, bytes, |_, _| Ok(()))?;
208205
}
209206
Self::NotImplemented { typ, payload } => {
210207
typ.encode(bytes)?;
@@ -215,15 +212,15 @@ impl ParameterizedEncode<DapVersion> for Extension {
215212
}
216213
}
217214

218-
impl ParameterizedDecode<DapVersion> for Extension {
219-
fn decode_with_param(
220-
version: &DapVersion,
221-
bytes: &mut Cursor<&[u8]>,
222-
) -> Result<Self, CodecError> {
215+
impl Decode for Extension {
216+
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
223217
let typ = u16::decode(bytes)?;
224218
match typ {
225219
EXTENSION_TASKPROV => {
226-
decode_u16_prefixed(*version, bytes, |_version, inner, _len| <()>::decode(inner))?;
220+
// We've hard coded the version here, but we don't actually use it.
221+
decode_u16_prefixed(DapVersion::Latest, bytes, |_version, inner, _len| {
222+
<()>::decode(inner)
223+
})?;
227224
Ok(Self::Taskprov)
228225
}
229226
_ => Ok(Self::NotImplemented {

0 commit comments

Comments
 (0)