Skip to content

Commit fe58917

Browse files
committed
Added derive macro for custom type handling, updated tests
1 parent bfbda30 commit fe58917

File tree

6 files changed

+95
-111
lines changed

6 files changed

+95
-111
lines changed

cipherstash-dynamodb-derive/src/decryptable.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ pub(crate) fn derive_decryptable(input: DeriveInput) -> Result<TokenStream, syn:
99
.field_attributes(&input)?
1010
.build()?;
1111

12-
let protected_attributes = settings.protected_attributes();
12+
let protected_excluding_handlers = settings.protected_attributes_excluding_handlers();
1313
let plaintext_attributes = settings.plaintext_attributes();
1414

1515
let protected_attributes_cow = settings
@@ -25,7 +25,7 @@ pub(crate) fn derive_decryptable(input: DeriveInput) -> Result<TokenStream, syn:
2525
let skipped_attributes = settings.skipped_attributes();
2626
let ident = settings.ident();
2727

28-
let from_unsealed_impl = protected_attributes
28+
let from_unsealed_impl = protected_excluding_handlers
2929
.iter()
3030
.map(|attr| {
3131
let attr_ident = format_ident!("{attr}");
@@ -47,6 +47,13 @@ pub(crate) fn derive_decryptable(input: DeriveInput) -> Result<TokenStream, syn:
4747
quote! {
4848
#attr_ident: Default::default()
4949
}
50+
}))
51+
.chain(settings.decrypt_handlers().iter().map(|(attr, handler)| {
52+
let attr_ident = format_ident!("{attr}");
53+
54+
quote! {
55+
#attr_ident: #handler(&mut unsealed)?
56+
}
5057
}));
5158

5259
let expanded = quote! {

cipherstash-dynamodb-derive/src/encryptable.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ pub(crate) fn derive_encryptable(input: DeriveInput) -> Result<TokenStream, syn:
99
.field_attributes(&input)?
1010
.build()?;
1111

12-
let protected_attributes = settings.protected_attributes();
12+
let protected_excluding_handlers = settings.protected_attributes_excluding_handlers();
1313
let plaintext_attributes = settings.plaintext_attributes();
1414

1515
let protected_attributes_cow = settings
@@ -24,20 +24,27 @@ pub(crate) fn derive_encryptable(input: DeriveInput) -> Result<TokenStream, syn:
2424

2525
let ident = settings.ident();
2626

27-
let into_unsealed_impl = protected_attributes
27+
let into_unsealed_impl = protected_excluding_handlers
2828
.iter()
2929
.map(|attr| {
3030
let attr_ident = format_ident!("{attr}");
3131

3232
quote! {
33-
unsealed.add_protected(#attr, cipherstash_dynamodb::traits::Plaintext::from(self.#attr_ident.to_owned()));
33+
unsealed.add_protected(#attr, self.#attr_ident);
3434
}
3535
})
3636
.chain(plaintext_attributes.iter().map(|attr| {
3737
let attr_ident = format_ident!("{attr}");
3838

3939
quote! {
40-
unsealed.add_unprotected(#attr, cipherstash_dynamodb::traits::TableAttribute::from(self.#attr_ident.clone()));
40+
unsealed.add_unprotected(#attr, self.#attr_ident);
41+
}
42+
}))
43+
.chain(settings.encrypt_handlers().iter().map(|(attr, handler)| {
44+
let attr_ident = format_ident!("{attr}");
45+
46+
quote! {
47+
#handler(&mut unsealed, self.#attr_ident);
4148
}
4249
}));
4350

cipherstash-dynamodb-derive/src/settings/builder.rs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::{index_type::IndexType, AttributeMode, Settings};
22
use proc_macro2::{Ident, Span};
33
use std::collections::HashMap;
4-
use syn::{Data, DeriveInput, Fields, LitStr};
4+
use syn::{Data, DeriveInput, Fields, LitStr, ExprPath};
55

66
enum SortKeyPrefix {
77
Default,
@@ -31,6 +31,8 @@ pub(crate) struct SettingsBuilder {
3131
unprotected_attributes: Vec<String>,
3232
skipped_attributes: Vec<String>,
3333
indexes: Vec<IndexType>,
34+
encrypt_handlers: HashMap<String, ExprPath>,
35+
decrypt_handlers: HashMap<String, ExprPath>,
3436
}
3537

3638
impl SettingsBuilder {
@@ -58,6 +60,8 @@ impl SettingsBuilder {
5860
unprotected_attributes: Vec::new(),
5961
skipped_attributes: Vec::new(),
6062
indexes: Vec::new(),
63+
encrypt_handlers: HashMap::new(),
64+
decrypt_handlers: HashMap::new(),
6165
}
6266
}
6367

@@ -288,6 +292,18 @@ impl SettingsBuilder {
288292

289293
Ok(())
290294
}
295+
Some("encryptable_with") => {
296+
let value = meta.value()?;
297+
let handler = value.parse::<ExprPath>()?;
298+
self.encrypt_handlers.insert(field_name.clone(), handler);
299+
Ok(())
300+
}
301+
Some("decryptable_with") => {
302+
let value = meta.value()?;
303+
let handler = value.parse::<ExprPath>()?;
304+
self.decrypt_handlers.insert(field_name.clone(), handler);
305+
Ok(())
306+
}
291307
_ => Err(meta.error("unsupported field attribute")),
292308
}
293309
})?;
@@ -308,7 +324,10 @@ impl SettingsBuilder {
308324
}
309325

310326
(None, Some((compound_index_name, span))) => {
311-
return Err(syn::Error::new(span, format!("Compound attribute was specified but no query options were. Specify how this field should be queried with the attribute #[cipherstash(query = <option>, compound = \"{compound_index_name}\")]")));
327+
return Err(syn::Error::new(
328+
span,
329+
format!("Compound attribute was specified but no query options were. Specify how this field should be queried with the attribute #[cipherstash(query = <option>, compound = \"{compound_index_name}\")]"))
330+
);
312331
}
313332

314333
(None, None) => {}
@@ -345,6 +364,8 @@ impl SettingsBuilder {
345364
unprotected_attributes,
346365
skipped_attributes,
347366
indexes,
367+
encrypt_handlers,
368+
decrypt_handlers,
348369
} = self;
349370

350371
let sort_key_prefix = sort_key_prefix.into_prefix(&type_name);
@@ -359,6 +380,8 @@ impl SettingsBuilder {
359380
unprotected_attributes,
360381
skipped_attributes,
361382
indexes,
383+
encrypt_handlers,
384+
decrypt_handlers,
362385
})
363386
}
364387

@@ -499,4 +522,4 @@ impl SettingsBuilder {
499522

500523
Ok(())
501524
}
502-
}
525+
}

cipherstash-dynamodb-derive/src/settings/mod.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
mod builder;
22
pub mod index_type;
3+
use std::collections::HashMap;
4+
35
use self::{builder::SettingsBuilder, index_type::IndexType};
46
use itertools::Itertools;
57
use proc_macro2::Ident;
6-
use syn::DeriveInput;
8+
use syn::{DeriveInput, ExprPath};
79

810
pub(crate) enum AttributeMode {
911
Protected,
@@ -20,6 +22,12 @@ pub(crate) struct Settings {
2022
protected_attributes: Vec<String>,
2123
unprotected_attributes: Vec<String>,
2224

25+
/// Map of attribute names to the encryption handler to use.
26+
encrypt_handlers: HashMap<String, ExprPath>,
27+
28+
/// Map of attribute names to the decryption handler to use.
29+
decrypt_handlers: HashMap<String, ExprPath>,
30+
2331
/// Skipped attributes are never encrypted by the `DecryptedRecord` trait will
2432
/// use these to reconstruct the struct via `Default` (like serde).
2533
skipped_attributes: Vec<String>,
@@ -43,6 +51,24 @@ impl Settings {
4351
.collect::<Vec<_>>()
4452
}
4553

54+
pub(crate) fn protected_attributes_excluding_handlers(&self) -> Vec<&str> {
55+
self.protected_attributes
56+
.iter()
57+
.filter(|s| !self.encrypt_handlers.contains_key(s.as_str()))
58+
.filter(|s| !self.decrypt_handlers.contains_key(s.as_str()))
59+
.map(|s| s.as_str())
60+
.sorted()
61+
.collect::<Vec<_>>()
62+
}
63+
64+
pub(crate) fn encrypt_handlers(&self) -> &HashMap<String, ExprPath> {
65+
&self.encrypt_handlers
66+
}
67+
68+
pub(crate) fn decrypt_handlers(&self) -> &HashMap<String, ExprPath> {
69+
&self.decrypt_handlers
70+
}
71+
4672
pub(crate) fn plaintext_attributes(&self) -> Vec<&str> {
4773
self.unprotected_attributes
4874
.iter()

src/crypto/attrs/normalized_protected_attributes.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,18 +96,17 @@ impl FromIterator<(NormalizedKey, NormalizedValue)> for NormalizedProtectedAttri
9696

9797
impl FromIterator<FlattenedProtectedAttribute> for NormalizedProtectedAttributes {
9898
fn from_iter<T: IntoIterator<Item = FlattenedProtectedAttribute>>(iter: T) -> Self {
99-
iter.into_iter()
100-
.fold(Self::new(), |mut acc, fpa| {
101-
match fpa.normalize_into_parts() {
102-
(plaintext, key, Some(subkey)) => {
103-
acc.insert_and_update_map(key, subkey, plaintext);
104-
}
105-
(plaintext, key, None) => {
106-
acc.insert(key, plaintext);
107-
}
99+
iter.into_iter().fold(Self::new(), |mut acc, fpa| {
100+
match fpa.normalize_into_parts() {
101+
(plaintext, key, Some(subkey)) => {
102+
acc.insert_and_update_map(key, subkey, plaintext);
108103
}
109-
acc
110-
})
104+
(plaintext, key, None) => {
105+
acc.insert(key, plaintext);
106+
}
107+
}
108+
acc
109+
})
111110
}
112111
}
113112

tests/nested_tests.rs

Lines changed: 12 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
mod common;
2-
3-
// TODO: Use the derive macros for this test
42
use cipherstash_client::encryption::TypeParseError;
53
use cipherstash_dynamodb::{
64
crypto::Unsealed,
75
errors::SealError,
8-
traits::{Plaintext, TryFromPlaintext, TryFromTableAttr},
9-
Decryptable, Encryptable, EncryptedTable, Identifiable, PkSk,
6+
traits::{Plaintext, TryFromPlaintext},
7+
Decryptable, Encryptable, EncryptedTable, Identifiable,
108
};
119
use cipherstash_dynamodb_derive::Searchable;
1210
use miette::IntoDiagnostic;
13-
use std::{borrow::Cow, collections::BTreeMap};
11+
use std::collections::BTreeMap;
1412

1513
fn make_btree_map() -> BTreeMap<String, String> {
1614
let mut map = BTreeMap::new();
@@ -20,116 +18,40 @@ fn make_btree_map() -> BTreeMap<String, String> {
2018
map
2119
}
2220

23-
#[derive(Debug, Clone, PartialEq, Searchable)]
21+
#[derive(Debug, Clone, PartialEq, Searchable, Encryptable, Decryptable, Identifiable)]
2422
struct Test {
2523
#[partition_key]
2624
pub pk: String,
2725
#[sort_key]
2826
pub sk: String,
2927
pub name: String,
3028
pub age: i16,
29+
#[cipherstash(plaintext)]
3130
pub tag: String,
31+
#[cipherstash(encryptable_with = put_attrs, decryptable_with = get_attrs)]
3232
pub attrs: BTreeMap<String, String>,
3333
}
3434

35-
impl Identifiable for Test {
36-
type PrimaryKey = PkSk;
37-
38-
fn get_primary_key(&self) -> Self::PrimaryKey {
39-
PkSk(self.pk.to_string(), self.sk.to_string())
40-
}
41-
#[inline]
42-
fn type_name() -> Cow<'static, str> {
43-
std::borrow::Cow::Borrowed("test")
44-
}
45-
#[inline]
46-
fn sort_key_prefix() -> Option<Cow<'static, str>> {
47-
None
48-
}
49-
fn is_pk_encrypted() -> bool {
50-
false
51-
}
52-
fn is_sk_encrypted() -> bool {
53-
false
54-
}
55-
}
56-
5735
fn put_attrs(unsealed: &mut Unsealed, attrs: BTreeMap<String, String>) {
5836
attrs.into_iter().for_each(|(k, v)| {
5937
unsealed.add_protected_map_field("attrs", k, Plaintext::from(v));
6038
})
6139
}
6240

63-
impl Encryptable for Test {
64-
fn protected_attributes() -> Cow<'static, [Cow<'static, str>]> {
65-
Cow::Borrowed(&[
66-
Cow::Borrowed("name"),
67-
Cow::Borrowed("age"),
68-
Cow::Borrowed("attrs"),
69-
])
70-
}
71-
72-
fn plaintext_attributes() -> Cow<'static, [Cow<'static, str>]> {
73-
Cow::Borrowed(&[
74-
Cow::Borrowed("tag"),
75-
Cow::Borrowed("pk"),
76-
Cow::Borrowed("sk"),
77-
])
78-
}
79-
80-
fn into_unsealed(self) -> Unsealed {
81-
let mut unsealed = Unsealed::new_with_descriptor(<Self as Identifiable>::type_name());
82-
unsealed.add_unprotected("pk", self.pk);
83-
unsealed.add_unprotected("sk", self.sk);
84-
unsealed.add_protected("name", self.name);
85-
unsealed.add_protected("age", self.age);
86-
unsealed.add_unprotected("tag", self.tag);
87-
put_attrs(&mut unsealed, self.attrs);
88-
println!("Encryption: {:?}", unsealed);
89-
unsealed
90-
}
91-
}
92-
93-
fn get_attrs<T>(unsealed: &mut Unsealed) -> Result<T, TypeParseError>
41+
fn get_attrs<T>(unsealed: &mut Unsealed) -> Result<T, SealError>
9442
where
9543
T: FromIterator<(String, String)>,
9644
{
9745
unsealed
9846
.take_protected_map("attrs")
9947
.ok_or(TypeParseError("attrs".to_string()))?
10048
.into_iter()
101-
.map(|(k, v)| TryFromPlaintext::try_from_plaintext(v).map(|v| (k, v)))
102-
.collect()
103-
}
104-
105-
impl Decryptable for Test {
106-
fn from_unsealed(mut unsealed: Unsealed) -> Result<Self, SealError> {
107-
println!("{:?}", unsealed);
108-
Ok(Self {
109-
pk: TryFromTableAttr::try_from_table_attr(unsealed.get_plaintext("pk"))?,
110-
sk: TryFromTableAttr::try_from_table_attr(unsealed.get_plaintext("sk"))?,
111-
name: TryFromPlaintext::try_from_optional_plaintext(unsealed.take_protected("name"))?,
112-
age: TryFromPlaintext::try_from_optional_plaintext(unsealed.take_protected("age"))?,
113-
tag: TryFromTableAttr::try_from_table_attr(unsealed.get_plaintext("tag"))?,
114-
attrs: get_attrs(&mut unsealed)?,
49+
.map(|(k, v)| {
50+
TryFromPlaintext::try_from_plaintext(v)
51+
.map(|v| (k, v))
52+
.map_err(SealError::from)
11553
})
116-
}
117-
118-
fn protected_attributes() -> Cow<'static, [Cow<'static, str>]> {
119-
Cow::Borrowed(&[
120-
Cow::Borrowed("name"),
121-
Cow::Borrowed("age"),
122-
Cow::Borrowed("attrs"),
123-
])
124-
}
125-
126-
fn plaintext_attributes() -> Cow<'static, [Cow<'static, str>]> {
127-
Cow::Borrowed(&[
128-
Cow::Borrowed("tag"),
129-
Cow::Borrowed("pk"),
130-
Cow::Borrowed("sk"),
131-
])
132-
}
54+
.collect()
13355
}
13456

13557
#[tokio::test]

0 commit comments

Comments
 (0)