Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/daphne-server/src/storage_proxy_connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl<'d, B: DurableMethod + Debug, P: AsRef<[u8]>> RequestBuilder<'d, B, P> {

impl<'d, B: DurableMethod> RequestBuilder<'d, B, [u8; 0]> {
pub fn encode<T: CapnprotoPayloadEncode>(self, payload: &T) -> RequestBuilder<'d, B, Vec<u8>> {
self.with_body(payload.encode_to_bytes().unwrap())
self.with_body(payload.encode_to_bytes())
}

pub fn with_body<T: AsRef<[u8]>>(self, payload: T) -> RequestBuilder<'d, B, T> {
Expand Down
24 changes: 15 additions & 9 deletions crates/daphne-service-utils/src/capnproto_payload.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

use capnp::traits::{FromPointerBuilder, FromPointerReader};

pub trait CapnprotoPayloadEncode {
fn encode_to_builder(&self) -> capnp::message::Builder<capnp::message::HeapAllocator>;
type Builder<'a>: FromPointerBuilder<'a>;

fn encode_to_builder(&self, builder: Self::Builder<'_>);
}

pub trait CapnprotoPayloadEncodeExt {
fn encode_to_bytes(&self) -> capnp::Result<Vec<u8>>;
fn encode_to_bytes(&self) -> Vec<u8>;
}

pub trait CapnprotoPayloadDecode {
fn decode_from_reader(
reader: capnp::message::Reader<capnp::serialize::OwnedSegments>,
) -> capnp::Result<Self>
type Reader<'a>: FromPointerReader<'a>;

fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
where
Self: Sized;
}
Expand All @@ -27,11 +31,12 @@ impl<T> CapnprotoPayloadEncodeExt for T
where
T: CapnprotoPayloadEncode,
{
fn encode_to_bytes(&self) -> capnp::Result<Vec<u8>> {
fn encode_to_bytes(&self) -> Vec<u8> {
let mut message = capnp::message::Builder::new_default();
self.encode_to_builder(message.init_root::<T::Builder<'_>>());
let mut buf = Vec::new();
let message = self.encode_to_builder();
capnp::serialize_packed::write_message(&mut buf, &message)?;
Ok(buf)
capnp::serialize_packed::write_message(&mut buf, &message).expect("infalible");
buf
}
}

Expand All @@ -49,6 +54,7 @@ where
capnp::message::ReaderOptions::new(),
)?;

let reader = reader.get_root::<T::Reader<'_>>()?;
T::decode_from_reader(reader)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,16 @@ pub struct AggregateStoreMergeOptions {
}

impl CapnprotoPayloadEncode for AggregateStoreMergeReq {
fn encode_to_builder(&self) -> capnp::message::Builder<capnp::message::HeapAllocator> {
type Builder<'a> = aggregate_store_merge_req::Builder<'a>;

fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
let Self {
contained_reports,
agg_share_delta,
options,
} = self;
let mut message = capnp::message::Builder::new_default();
let mut request = message.init_root::<aggregate_store_merge_req::Builder>();
{
let mut contained_reports = request.reborrow().init_contained_reports(
let mut contained_reports = builder.reborrow().init_contained_reports(
contained_reports
.len()
.try_into()
Expand All @@ -94,7 +94,7 @@ impl CapnprotoPayloadEncode for AggregateStoreMergeReq {
}
}
{
let mut agg_share_delta_packet = request.reborrow().init_agg_share_delta();
let mut agg_share_delta_packet = builder.reborrow().init_agg_share_delta();
agg_share_delta_packet.set_report_count(agg_share_delta.report_count);
agg_share_delta_packet.set_min_time(agg_share_delta.min_time);
agg_share_delta_packet.set_max_time(agg_share_delta.max_time);
Expand Down Expand Up @@ -157,20 +157,18 @@ impl CapnprotoPayloadEncode for AggregateStoreMergeReq {
let AggregateStoreMergeOptions {
skip_replay_protection,
} = options;
let mut options_packet = request.init_options();
let mut options_packet = builder.init_options();
options_packet.set_skip_replay_protection(*skip_replay_protection);
}
message
}
}

impl CapnprotoPayloadDecode for AggregateStoreMergeReq {
fn decode_from_reader(
reader: capnp::message::Reader<capnp::serialize::OwnedSegments>,
) -> capnp::Result<Self> {
let request = reader.get_root::<aggregate_store_merge_req::Reader>()?;
type Reader<'a> = aggregate_store_merge_req::Reader<'a>;

fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self> {
let agg_share_delta = {
let agg_share_delta = request.get_agg_share_delta()?;
let agg_share_delta = reader.get_agg_share_delta()?;
let data = {
macro_rules! make_decode {
($func_name:ident, $agg_share_type:ident, $field_trait:ident, $field_error:ident) => {
Expand Down Expand Up @@ -238,8 +236,7 @@ impl CapnprotoPayloadDecode for AggregateStoreMergeReq {
}
};
let contained_reports = {
request
.reborrow()
reader
.get_contained_reports()?
.into_iter()
.map(|report| {
Expand All @@ -257,7 +254,7 @@ impl CapnprotoPayloadDecode for AggregateStoreMergeReq {
contained_reports,
agg_share_delta,
options: AggregateStoreMergeOptions {
skip_replay_protection: request.get_options()?.get_skip_replay_protection(),
skip_replay_protection: reader.get_options()?.get_skip_replay_protection(),
},
})
}
Expand Down Expand Up @@ -352,8 +349,7 @@ mod test {
},
};
let other =
AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes().unwrap())
.unwrap();
AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes()).unwrap();
assert_eq!(this, other);
}
}
Expand Down Expand Up @@ -411,8 +407,7 @@ mod test {
},
};
let other =
AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes().unwrap())
.unwrap();
AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes()).unwrap();
assert_eq!(this, other);
}
}
Expand Down
Loading