diff --git a/crates/daphne-server/src/storage_proxy_connection/mod.rs b/crates/daphne-server/src/storage_proxy_connection/mod.rs index 4e5cb587f..555d0084e 100644 --- a/crates/daphne-server/src/storage_proxy_connection/mod.rs +++ b/crates/daphne-server/src/storage_proxy_connection/mod.rs @@ -96,7 +96,7 @@ impl<'d, B: DurableMethod + Debug, P: AsRef<[u8]>> RequestBuilder<'d, B, P> { impl<'d, B: DurableMethod> RequestBuilder<'d, B, [u8; 0]> { pub fn encode(self, payload: &T) -> RequestBuilder<'d, B, Vec> { - self.with_body(payload.encode_to_bytes().unwrap()) + self.with_body(payload.encode_to_bytes()) } pub fn with_body>(self, payload: T) -> RequestBuilder<'d, B, T> { diff --git a/crates/daphne-service-utils/src/capnproto_payload.rs b/crates/daphne-service-utils/src/capnproto_payload.rs index 01dbbbe53..e3c2b6aa1 100644 --- a/crates/daphne-service-utils/src/capnproto_payload.rs +++ b/crates/daphne-service-utils/src/capnproto_payload.rs @@ -1,18 +1,22 @@ // Copyright (c) 2024 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause +use capnp::traits::{FromPointerBuilder, FromPointerReader}; + pub trait CapnprotoPayloadEncode { - fn encode_to_builder(&self) -> capnp::message::Builder; + type Builder<'a>: FromPointerBuilder<'a>; + + fn encode_to_builder(&self, builder: Self::Builder<'_>); } pub trait CapnprotoPayloadEncodeExt { - fn encode_to_bytes(&self) -> capnp::Result>; + fn encode_to_bytes(&self) -> Vec; } pub trait CapnprotoPayloadDecode { - fn decode_from_reader( - reader: capnp::message::Reader, - ) -> capnp::Result + type Reader<'a>: FromPointerReader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result where Self: Sized; } @@ -27,11 +31,12 @@ impl CapnprotoPayloadEncodeExt for T where T: CapnprotoPayloadEncode, { - fn encode_to_bytes(&self) -> capnp::Result> { + fn encode_to_bytes(&self) -> Vec { + let mut message = capnp::message::Builder::new_default(); + self.encode_to_builder(message.init_root::>()); let mut buf = Vec::new(); - let message = self.encode_to_builder(); - capnp::serialize_packed::write_message(&mut buf, &message)?; - Ok(buf) + capnp::serialize_packed::write_message(&mut buf, &message).expect("infalible"); + buf } } @@ -49,6 +54,7 @@ where capnp::message::ReaderOptions::new(), )?; + let reader = reader.get_root::>()?; T::decode_from_reader(reader) } } diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs index 51bf41fea..eb8c4cca8 100644 --- a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs +++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs @@ -69,16 +69,16 @@ pub struct AggregateStoreMergeOptions { } impl CapnprotoPayloadEncode for AggregateStoreMergeReq { - fn encode_to_builder(&self) -> capnp::message::Builder { + type Builder<'a> = aggregate_store_merge_req::Builder<'a>; + + fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { let Self { contained_reports, agg_share_delta, options, } = self; - let mut message = capnp::message::Builder::new_default(); - let mut request = message.init_root::(); { - let mut contained_reports = request.reborrow().init_contained_reports( + let mut contained_reports = builder.reborrow().init_contained_reports( contained_reports .len() .try_into() @@ -94,7 +94,7 @@ impl CapnprotoPayloadEncode for AggregateStoreMergeReq { } } { - let mut agg_share_delta_packet = request.reborrow().init_agg_share_delta(); + let mut agg_share_delta_packet = builder.reborrow().init_agg_share_delta(); agg_share_delta_packet.set_report_count(agg_share_delta.report_count); agg_share_delta_packet.set_min_time(agg_share_delta.min_time); agg_share_delta_packet.set_max_time(agg_share_delta.max_time); @@ -157,20 +157,18 @@ impl CapnprotoPayloadEncode for AggregateStoreMergeReq { let AggregateStoreMergeOptions { skip_replay_protection, } = options; - let mut options_packet = request.init_options(); + let mut options_packet = builder.init_options(); options_packet.set_skip_replay_protection(*skip_replay_protection); } - message } } impl CapnprotoPayloadDecode for AggregateStoreMergeReq { - fn decode_from_reader( - reader: capnp::message::Reader, - ) -> capnp::Result { - let request = reader.get_root::()?; + type Reader<'a> = aggregate_store_merge_req::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result { let agg_share_delta = { - let agg_share_delta = request.get_agg_share_delta()?; + let agg_share_delta = reader.get_agg_share_delta()?; let data = { macro_rules! make_decode { ($func_name:ident, $agg_share_type:ident, $field_trait:ident, $field_error:ident) => { @@ -238,8 +236,7 @@ impl CapnprotoPayloadDecode for AggregateStoreMergeReq { } }; let contained_reports = { - request - .reborrow() + reader .get_contained_reports()? .into_iter() .map(|report| { @@ -257,7 +254,7 @@ impl CapnprotoPayloadDecode for AggregateStoreMergeReq { contained_reports, agg_share_delta, options: AggregateStoreMergeOptions { - skip_replay_protection: request.get_options()?.get_skip_replay_protection(), + skip_replay_protection: reader.get_options()?.get_skip_replay_protection(), }, }) } @@ -352,8 +349,7 @@ mod test { }, }; let other = - AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes().unwrap()) - .unwrap(); + AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes()).unwrap(); assert_eq!(this, other); } } @@ -411,8 +407,7 @@ mod test { }, }; let other = - AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes().unwrap()) - .unwrap(); + AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes()).unwrap(); assert_eq!(this, other); } }