Skip to content

Commit cbbd376

Browse files
committed
Create base capnproto lib for basic dap types
1 parent 964bb98 commit cbbd376

File tree

9 files changed

+286
-68
lines changed

9 files changed

+286
-68
lines changed

crates/daphne-server/src/storage_proxy_connection/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::fmt::Debug;
77

88
use axum::http::StatusCode;
99
use daphne_service_utils::{
10-
capnproto_payload::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt as _},
10+
capnproto::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt as _},
1111
durable_requests::{bindings::DurableMethod, DurableRequest, ObjectIdFrom, DO_PATH_PREFIX},
1212
};
1313
use serde::de::DeserializeOwned;

crates/daphne-service-utils/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
fn main() {
55
#[cfg(feature = "durable_requests")]
66
::capnpc::CompilerCommand::new()
7+
.file("./src/capnproto/base.capnp")
78
.file("./src/durable_requests/durable_request.capnp")
89
.run()
910
.expect("compiling schema");
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
@0xba869f168ff63e77;
5+
6+
enum DapVersion @0xb5b2c8705a8b22d5 {
7+
draft09 @0;
8+
draftLatest @1;
9+
}
10+
11+
# [u8; 32]
12+
struct U8L32 @0x9e42cda292792294 {
13+
fst @0 :UInt64;
14+
snd @1 :UInt64;
15+
thr @2 :UInt64;
16+
frh @3 :UInt64;
17+
}
18+
19+
# [u8; 16]
20+
struct U8L16 @0x9e3f65b13f71cfcb {
21+
fst @0 :UInt64;
22+
snd @1 :UInt64;
23+
}
24+
25+
struct PartialBatchSelector {
26+
union {
27+
timeInterval @0 :Void;
28+
leaderSelectedByBatchId @1 :BatchId;
29+
}
30+
}
31+
32+
using ReportId = U8L16;
33+
using BatchId = U8L32;
34+
using TaskId = U8L32;
35+
using AggregationJobId = U8L16;
36+
using Time = UInt64;
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
use crate::base_capnp::{self, partial_batch_selector, u8_l16, u8_l32};
5+
use capnp::struct_list;
6+
use capnp::traits::{FromPointerBuilder, FromPointerReader};
7+
use daphne::{
8+
messages::{AggregationJobId, BatchId, PartialBatchSelector, ReportId, TaskId},
9+
DapVersion,
10+
};
11+
12+
pub trait CapnprotoPayloadEncode {
13+
type Builder<'a>: FromPointerBuilder<'a>;
14+
15+
fn encode_to_builder(&self, builder: Self::Builder<'_>);
16+
}
17+
18+
pub trait CapnprotoPayloadEncodeExt {
19+
fn encode_to_bytes(&self) -> Vec<u8>;
20+
}
21+
22+
pub trait CapnprotoPayloadDecode {
23+
type Reader<'a>: FromPointerReader<'a>;
24+
25+
fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
26+
where
27+
Self: Sized;
28+
}
29+
30+
pub trait CapnprotoPayloadDecodeExt {
31+
fn decode_from_bytes(bytes: &[u8]) -> capnp::Result<Self>
32+
where
33+
Self: Sized;
34+
}
35+
36+
impl<T> CapnprotoPayloadEncodeExt for T
37+
where
38+
T: CapnprotoPayloadEncode,
39+
{
40+
fn encode_to_bytes(&self) -> Vec<u8> {
41+
let mut message = capnp::message::Builder::new_default();
42+
self.encode_to_builder(message.init_root::<T::Builder<'_>>());
43+
let mut buf = Vec::new();
44+
capnp::serialize_packed::write_message(&mut buf, &message).expect("infalible");
45+
buf
46+
}
47+
}
48+
49+
impl<T> CapnprotoPayloadDecodeExt for T
50+
where
51+
T: CapnprotoPayloadDecode,
52+
{
53+
fn decode_from_bytes(bytes: &[u8]) -> capnp::Result<Self>
54+
where
55+
Self: Sized,
56+
{
57+
let mut cursor = std::io::Cursor::new(bytes);
58+
let reader = capnp::serialize_packed::read_message(
59+
&mut cursor,
60+
capnp::message::ReaderOptions::new(),
61+
)?;
62+
63+
let reader = reader.get_root::<T::Reader<'_>>()?;
64+
T::decode_from_reader(reader)
65+
}
66+
}
67+
68+
impl<T> CapnprotoPayloadEncode for &'_ T
69+
where
70+
T: CapnprotoPayloadEncode,
71+
{
72+
type Builder<'a> = T::Builder<'a>;
73+
74+
fn encode_to_builder(&self, builder: Self::Builder<'_>) {
75+
T::encode_to_builder(self, builder);
76+
}
77+
}
78+
79+
impl From<base_capnp::DapVersion> for DapVersion {
80+
fn from(val: base_capnp::DapVersion) -> Self {
81+
match val {
82+
base_capnp::DapVersion::Draft09 => DapVersion::Draft09,
83+
base_capnp::DapVersion::DraftLatest => DapVersion::Latest,
84+
}
85+
}
86+
}
87+
88+
impl From<DapVersion> for base_capnp::DapVersion {
89+
fn from(value: DapVersion) -> Self {
90+
match value {
91+
DapVersion::Draft09 => base_capnp::DapVersion::Draft09,
92+
DapVersion::Latest => base_capnp::DapVersion::DraftLatest,
93+
}
94+
}
95+
}
96+
97+
impl CapnprotoPayloadEncode for [u8; 32] {
98+
type Builder<'a> = u8_l32::Builder<'a>;
99+
100+
fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
101+
builder.set_fst(u64::from_le_bytes(self[0..8].try_into().unwrap()));
102+
builder.set_snd(u64::from_le_bytes(self[8..16].try_into().unwrap()));
103+
builder.set_thr(u64::from_le_bytes(self[16..24].try_into().unwrap()));
104+
builder.set_frh(u64::from_le_bytes(self[24..32].try_into().unwrap()));
105+
}
106+
}
107+
108+
impl CapnprotoPayloadDecode for [u8; 32] {
109+
type Reader<'a> = u8_l32::Reader<'a>;
110+
111+
fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
112+
where
113+
Self: Sized,
114+
{
115+
let mut array = [0; 32];
116+
array[0..8].copy_from_slice(&reader.get_fst().to_le_bytes());
117+
array[8..16].copy_from_slice(&reader.get_snd().to_le_bytes());
118+
array[16..24].copy_from_slice(&reader.get_thr().to_le_bytes());
119+
array[24..32].copy_from_slice(&reader.get_frh().to_le_bytes());
120+
Ok(array)
121+
}
122+
}
123+
124+
impl CapnprotoPayloadEncode for [u8; 16] {
125+
type Builder<'a> = u8_l16::Builder<'a>;
126+
127+
fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
128+
builder.set_fst(u64::from_le_bytes(self[0..8].try_into().unwrap()));
129+
builder.set_snd(u64::from_le_bytes(self[8..16].try_into().unwrap()));
130+
}
131+
}
132+
133+
impl CapnprotoPayloadDecode for [u8; 16] {
134+
type Reader<'a> = u8_l16::Reader<'a>;
135+
136+
fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
137+
where
138+
Self: Sized,
139+
{
140+
let mut array = [0; 16];
141+
array[0..8].copy_from_slice(&reader.get_fst().to_le_bytes());
142+
array[8..16].copy_from_slice(&reader.get_snd().to_le_bytes());
143+
Ok(array)
144+
}
145+
}
146+
147+
macro_rules! capnp_encode_ids {
148+
($($id:ident => $inner:ident),*$(,)?) => {
149+
$(
150+
impl CapnprotoPayloadEncode for $id {
151+
type Builder<'a> = $inner::Builder<'a>;
152+
153+
fn encode_to_builder(&self, builder: Self::Builder<'_>) {
154+
self.0.encode_to_builder(builder)
155+
}
156+
}
157+
158+
impl CapnprotoPayloadDecode for $id {
159+
type Reader<'a> = $inner::Reader<'a>;
160+
161+
fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
162+
where
163+
Self: Sized,
164+
{
165+
<_>::decode_from_reader(reader).map(Self)
166+
}
167+
}
168+
)*
169+
};
170+
}
171+
172+
capnp_encode_ids! {
173+
TaskId => u8_l32,
174+
ReportId => u8_l16,
175+
BatchId => u8_l32,
176+
AggregationJobId => u8_l16,
177+
}
178+
179+
impl CapnprotoPayloadEncode for PartialBatchSelector {
180+
type Builder<'a> = partial_batch_selector::Builder<'a>;
181+
182+
fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
183+
match self {
184+
PartialBatchSelector::TimeInterval => builder.set_time_interval(()),
185+
PartialBatchSelector::LeaderSelectedByBatchId { batch_id } => {
186+
batch_id.encode_to_builder(builder.init_leader_selected_by_batch_id());
187+
}
188+
}
189+
}
190+
}
191+
192+
impl CapnprotoPayloadDecode for PartialBatchSelector {
193+
type Reader<'a> = partial_batch_selector::Reader<'a>;
194+
195+
fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self> {
196+
match reader.which()? {
197+
partial_batch_selector::Which::TimeInterval(()) => Ok(Self::TimeInterval),
198+
partial_batch_selector::Which::LeaderSelectedByBatchId(reader) => {
199+
Ok(Self::LeaderSelectedByBatchId {
200+
batch_id: <_>::decode_from_reader(reader?)?,
201+
})
202+
}
203+
}
204+
}
205+
}
206+
207+
pub fn encode_list<I, O>(list: I, mut builder: struct_list::Builder<'_, O>)
208+
where
209+
I: IntoIterator<Item: CapnprotoPayloadEncode>,
210+
O: for<'b> capnp::traits::OwnedStruct<
211+
Builder<'b> = <I::Item as CapnprotoPayloadEncode>::Builder<'b>,
212+
>,
213+
{
214+
for (i, item) in list.into_iter().enumerate() {
215+
item.encode_to_builder(builder.reborrow().get(i.try_into().unwrap()));
216+
}
217+
}
218+
219+
pub fn decode_list<T, O, C>(reader: struct_list::Reader<'_, O>) -> capnp::Result<C>
220+
where
221+
T: CapnprotoPayloadDecode,
222+
C: FromIterator<T>,
223+
O: for<'b> capnp::traits::OwnedStruct<Reader<'b> = T::Reader<'b>>,
224+
{
225+
reader.into_iter().map(T::decode_from_reader).collect()
226+
}
227+
228+
pub fn usize_to_capnp_len(u: usize) -> u32 {
229+
u.try_into()
230+
.expect("capnp can't encode more that u32::MAX of something")
231+
}

crates/daphne-service-utils/src/capnproto_payload.rs

Lines changed: 0 additions & 60 deletions
This file was deleted.

crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use daphne::{
1111
use serde::{Deserialize, Serialize};
1212

1313
use crate::{
14-
capnproto_payload::{CapnprotoPayloadDecode, CapnprotoPayloadEncode},
14+
capnproto::{CapnprotoPayloadDecode, CapnprotoPayloadEncode},
1515
durable_request_capnp::{aggregate_store_merge_req, dap_aggregate_share},
1616
durable_requests::ObjectIdFrom,
1717
};
@@ -284,9 +284,7 @@ mod test {
284284
};
285285
use rand::{thread_rng, Rng};
286286

287-
use crate::capnproto_payload::{
288-
CapnprotoPayloadDecodeExt as _, CapnprotoPayloadEncodeExt as _,
289-
};
287+
use crate::capnproto::{CapnprotoPayloadDecodeExt as _, CapnprotoPayloadEncodeExt as _};
290288

291289
use super::*;
292290

crates/daphne-service-utils/src/lib.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,26 @@
55

66
pub mod bearer_token;
77
#[cfg(feature = "durable_requests")]
8-
pub mod capnproto_payload;
8+
pub mod capnproto;
99
#[cfg(feature = "durable_requests")]
1010
pub mod durable_requests;
1111
pub mod http_headers;
1212
#[cfg(feature = "test-utils")]
1313
pub mod test_route_types;
1414

1515
// the generated code expects this module to be defined at the root of the library.
16+
#[cfg(feature = "durable_requests")]
17+
#[doc(hidden)]
18+
pub mod base_capnp {
19+
#![allow(dead_code)]
20+
#![allow(clippy::pedantic)]
21+
#![allow(clippy::needless_lifetimes)]
22+
include!(concat!(
23+
env!("OUT_DIR"),
24+
"/src/capnproto_payload/base_capnp.rs"
25+
));
26+
}
27+
1628
#[cfg(feature = "durable_requests")]
1729
mod durable_request_capnp {
1830
#![allow(dead_code)]

0 commit comments

Comments
 (0)