Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/daphne-server/src/storage_proxy_connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::fmt::Debug;

use axum::http::StatusCode;
use daphne_service_utils::{
capnproto_payload::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt as _},
capnproto::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt as _},
durable_requests::{bindings::DurableMethod, DurableRequest, ObjectIdFrom, DO_PATH_PREFIX},
};
use serde::de::DeserializeOwned;
Expand Down
1 change: 1 addition & 0 deletions crates/daphne-service-utils/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
fn main() {
#[cfg(feature = "durable_requests")]
::capnpc::CompilerCommand::new()
.file("./src/capnproto/base.capnp")
.file("./src/durable_requests/durable_request.capnp")
.run()
.expect("compiling schema");
Expand Down
36 changes: 36 additions & 0 deletions crates/daphne-service-utils/src/capnproto/base.capnp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

@0xba869f168ff63e77;

enum DapVersion @0xb5b2c8705a8b22d5 {
draft09 @0;
draftLatest @1;
}

# [u8; 32]
struct U8L32 @0x9e42cda292792294 {
fst @0 :UInt64;
snd @1 :UInt64;
thr @2 :UInt64;
frh @3 :UInt64;
}

# [u8; 16]
struct U8L16 @0x9e3f65b13f71cfcb {
fst @0 :UInt64;
snd @1 :UInt64;
}

struct PartialBatchSelector {
union {
timeInterval @0 :Void;
leaderSelectedByBatchId @1 :BatchId;
}
}

using ReportId = U8L16;
using BatchId = U8L32;
using TaskId = U8L32;
using AggregationJobId = U8L16;
using Time = UInt64;
231 changes: 231 additions & 0 deletions crates/daphne-service-utils/src/capnproto/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

use crate::base_capnp::{self, partial_batch_selector, u8_l16, u8_l32};
use capnp::struct_list;
use capnp::traits::{FromPointerBuilder, FromPointerReader};
use daphne::{
messages::{AggregationJobId, BatchId, PartialBatchSelector, ReportId, TaskId},
DapVersion,
};

pub trait CapnprotoPayloadEncode {
type Builder<'a>: FromPointerBuilder<'a>;

fn encode_to_builder(&self, builder: Self::Builder<'_>);
}

pub trait CapnprotoPayloadEncodeExt {
fn encode_to_bytes(&self) -> Vec<u8>;
}

pub trait CapnprotoPayloadDecode {
type Reader<'a>: FromPointerReader<'a>;

fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
where
Self: Sized;
}

pub trait CapnprotoPayloadDecodeExt {
fn decode_from_bytes(bytes: &[u8]) -> capnp::Result<Self>
where
Self: Sized;
}

impl<T> CapnprotoPayloadEncodeExt for T
where
T: CapnprotoPayloadEncode,
{
fn encode_to_bytes(&self) -> Vec<u8> {
let mut message = capnp::message::Builder::new_default();
self.encode_to_builder(message.init_root::<T::Builder<'_>>());
let mut buf = Vec::new();
capnp::serialize_packed::write_message(&mut buf, &message).expect("infalible");
buf
}
}

impl<T> CapnprotoPayloadDecodeExt for T
where
T: CapnprotoPayloadDecode,
{
fn decode_from_bytes(bytes: &[u8]) -> capnp::Result<Self>
where
Self: Sized,
{
let mut cursor = std::io::Cursor::new(bytes);
let reader = capnp::serialize_packed::read_message(
&mut cursor,
capnp::message::ReaderOptions::new(),
)?;

let reader = reader.get_root::<T::Reader<'_>>()?;
T::decode_from_reader(reader)
}
}

impl<T> CapnprotoPayloadEncode for &'_ T
where
T: CapnprotoPayloadEncode,
{
type Builder<'a> = T::Builder<'a>;

fn encode_to_builder(&self, builder: Self::Builder<'_>) {
T::encode_to_builder(self, builder);
}
}

impl From<base_capnp::DapVersion> for DapVersion {
fn from(val: base_capnp::DapVersion) -> Self {
match val {
base_capnp::DapVersion::Draft09 => DapVersion::Draft09,
base_capnp::DapVersion::DraftLatest => DapVersion::Latest,
}
}
}

impl From<DapVersion> for base_capnp::DapVersion {
fn from(value: DapVersion) -> Self {
match value {
DapVersion::Draft09 => base_capnp::DapVersion::Draft09,
DapVersion::Latest => base_capnp::DapVersion::DraftLatest,
}
}
}

impl CapnprotoPayloadEncode for [u8; 32] {
type Builder<'a> = u8_l32::Builder<'a>;

fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
builder.set_fst(u64::from_le_bytes(self[0..8].try_into().unwrap()));
builder.set_snd(u64::from_le_bytes(self[8..16].try_into().unwrap()));
builder.set_thr(u64::from_le_bytes(self[16..24].try_into().unwrap()));
builder.set_frh(u64::from_le_bytes(self[24..32].try_into().unwrap()));
}
}

impl CapnprotoPayloadDecode for [u8; 32] {
type Reader<'a> = u8_l32::Reader<'a>;

fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
where
Self: Sized,
{
let mut array = [0; 32];
array[0..8].copy_from_slice(&reader.get_fst().to_le_bytes());
array[8..16].copy_from_slice(&reader.get_snd().to_le_bytes());
array[16..24].copy_from_slice(&reader.get_thr().to_le_bytes());
array[24..32].copy_from_slice(&reader.get_frh().to_le_bytes());
Ok(array)
}
}

impl CapnprotoPayloadEncode for [u8; 16] {
type Builder<'a> = u8_l16::Builder<'a>;

fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
builder.set_fst(u64::from_le_bytes(self[0..8].try_into().unwrap()));
builder.set_snd(u64::from_le_bytes(self[8..16].try_into().unwrap()));
}
}

impl CapnprotoPayloadDecode for [u8; 16] {
type Reader<'a> = u8_l16::Reader<'a>;

fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
where
Self: Sized,
{
let mut array = [0; 16];
array[0..8].copy_from_slice(&reader.get_fst().to_le_bytes());
array[8..16].copy_from_slice(&reader.get_snd().to_le_bytes());
Ok(array)
}
}

macro_rules! capnp_encode_ids {
($($id:ident => $inner:ident),*$(,)?) => {
$(
impl CapnprotoPayloadEncode for $id {
type Builder<'a> = $inner::Builder<'a>;

fn encode_to_builder(&self, builder: Self::Builder<'_>) {
self.0.encode_to_builder(builder)
}
}

impl CapnprotoPayloadDecode for $id {
type Reader<'a> = $inner::Reader<'a>;

fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
where
Self: Sized,
{
<_>::decode_from_reader(reader).map(Self)
}
}
)*
};
}

capnp_encode_ids! {
TaskId => u8_l32,
ReportId => u8_l16,
BatchId => u8_l32,
AggregationJobId => u8_l16,
}

impl CapnprotoPayloadEncode for PartialBatchSelector {
type Builder<'a> = partial_batch_selector::Builder<'a>;

fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
match self {
PartialBatchSelector::TimeInterval => builder.set_time_interval(()),
PartialBatchSelector::LeaderSelectedByBatchId { batch_id } => {
batch_id.encode_to_builder(builder.init_leader_selected_by_batch_id());
}
}
}
}

impl CapnprotoPayloadDecode for PartialBatchSelector {
type Reader<'a> = partial_batch_selector::Reader<'a>;

fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self> {
match reader.which()? {
partial_batch_selector::Which::TimeInterval(()) => Ok(Self::TimeInterval),
partial_batch_selector::Which::LeaderSelectedByBatchId(reader) => {
Ok(Self::LeaderSelectedByBatchId {
batch_id: <_>::decode_from_reader(reader?)?,
})
}
}
}
}

pub fn encode_list<I, O>(list: I, mut builder: struct_list::Builder<'_, O>)
where
I: IntoIterator<Item: CapnprotoPayloadEncode>,
O: for<'b> capnp::traits::OwnedStruct<
Builder<'b> = <I::Item as CapnprotoPayloadEncode>::Builder<'b>,
>,
{
for (i, item) in list.into_iter().enumerate() {
item.encode_to_builder(builder.reborrow().get(i.try_into().unwrap()));
}
}

pub fn decode_list<T, O, C>(reader: struct_list::Reader<'_, O>) -> capnp::Result<C>
where
T: CapnprotoPayloadDecode,
C: FromIterator<T>,
O: for<'b> capnp::traits::OwnedStruct<Reader<'b> = T::Reader<'b>>,
{
reader.into_iter().map(T::decode_from_reader).collect()
}

pub fn usize_to_capnp_len(u: usize) -> u32 {
u.try_into()
.expect("capnp can't encode more that u32::MAX of something")
}
60 changes: 0 additions & 60 deletions crates/daphne-service-utils/src/capnproto_payload.rs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use daphne::{
use serde::{Deserialize, Serialize};

use crate::{
capnproto_payload::{CapnprotoPayloadDecode, CapnprotoPayloadEncode},
capnproto::{CapnprotoPayloadDecode, CapnprotoPayloadEncode},
durable_request_capnp::{aggregate_store_merge_req, dap_aggregate_share},
durable_requests::ObjectIdFrom,
};
Expand Down Expand Up @@ -284,9 +284,7 @@ mod test {
};
use rand::{thread_rng, Rng};

use crate::capnproto_payload::{
CapnprotoPayloadDecodeExt as _, CapnprotoPayloadEncodeExt as _,
};
use crate::capnproto::{CapnprotoPayloadDecodeExt as _, CapnprotoPayloadEncodeExt as _};

use super::*;

Expand Down
11 changes: 10 additions & 1 deletion crates/daphne-service-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,23 @@

pub mod bearer_token;
#[cfg(feature = "durable_requests")]
pub mod capnproto_payload;
pub mod capnproto;
#[cfg(feature = "durable_requests")]
pub mod durable_requests;
pub mod http_headers;
#[cfg(feature = "test-utils")]
pub mod test_route_types;

// the generated code expects this module to be defined at the root of the library.
#[cfg(feature = "durable_requests")]
#[doc(hidden)]
pub mod base_capnp {
#![allow(dead_code)]
#![allow(clippy::pedantic)]
#![allow(clippy::needless_lifetimes)]
include!(concat!(env!("OUT_DIR"), "/src/capnproto/base_capnp.rs"));
}

#[cfg(feature = "durable_requests")]
mod durable_request_capnp {
#![allow(dead_code)]
Expand Down
Loading
Loading