Skip to content

Commit 3f762ad

Browse files
committed
Added dataset aware Put and Get for EncryptedTable
1 parent 72f76cd commit 3f762ad

File tree

5 files changed

+139
-35
lines changed

5 files changed

+139
-35
lines changed

src/crypto/sealer.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,6 @@ impl Sealer {
135135
records: impl IntoIterator<Item = Sealer>,
136136
protected_attributes: impl AsRef<[Cow<'a, str>]>,
137137
cipher: &ScopedZeroKmsCipher,
138-
// FIXME: This might need to be a const generic
139-
term_length: usize,
140138
) -> Result<RecordsWithTerms, SealError> {
141139
let protected_attributes = protected_attributes.as_ref();
142140
let num_protected_attributes = protected_attributes.len();
@@ -211,9 +209,8 @@ impl Sealer {
211209
records: impl IntoIterator<Item = Sealer>,
212210
protected_attributes: impl AsRef<[Cow<'a, str>]>,
213211
cipher: &ScopedZeroKmsCipher,
214-
term_length: usize,
215212
) -> Result<Vec<Sealed>, SealError> {
216-
Self::index_all_terms(records, protected_attributes, &cipher, term_length)?
213+
Self::index_all_terms(records, protected_attributes, &cipher)?
217214
.encrypt(cipher)
218215
.await
219216
}
@@ -222,9 +219,8 @@ impl Sealer {
222219
self,
223220
protected_attributes: impl AsRef<[Cow<'a, str>]>,
224221
cipher: &ScopedZeroKmsCipher,
225-
term_length: usize,
226222
) -> Result<Sealed, SealError> {
227-
let mut vec = Self::seal_all([self], protected_attributes, cipher, term_length).await?;
223+
let mut vec = Self::seal_all([self], protected_attributes, cipher).await?;
228224

229225
if vec.len() != 1 {
230226
let actual = vec.len();

src/encrypted_table/mod.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ use std::{
3131
ops::Deref, sync::Arc,
3232
};
3333

34-
/// Index terms are truncated to this length
35-
const DEFAULT_TERM_SIZE: usize = 12;
36-
3734
pub struct Headless;
3835

3936
pub struct Dynamo {
@@ -324,7 +321,7 @@ impl<D> EncryptedTable<D> {
324321
) -> Result<DynamoRecordPatch, PutError> {
325322
let mut seen_sk = HashSet::new();
326323

327-
let indexable_cipher = ScopedZeroKmsCipher::init(self.cipher.clone(), dataset_id).await.unwrap();
324+
let indexable_cipher = ScopedZeroKmsCipher::init(self.cipher.clone(), dataset_id).await?;
328325

329326
let PreparedRecord {
330327
protected_attributes,
@@ -334,7 +331,7 @@ impl<D> EncryptedTable<D> {
334331

335332
// Do the encryption
336333
let sealed = sealer
337-
.seal(protected_attributes, &indexable_cipher, DEFAULT_TERM_SIZE)
334+
.seal(protected_attributes, &indexable_cipher)
338335
.await?;
339336

340337
let mut put_records = Vec::with_capacity(sealed.len());

src/errors/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ pub enum PutError {
3434

3535
#[error(transparent)]
3636
DynamoError(#[from] SdkError<operation::transact_write_items::TransactWriteItemsError>),
37+
38+
#[error("ZeroKMS Error: {0}")]
39+
ZeroKMS(#[from] zerokms::Error),
3740
}
3841

3942
/// Error returned by `EncryptedTable::get` when retrieving and decrypting records from DynamoDB

tests/common.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,51 @@ pub async fn create_table(client: &Client, table_name: &str) {
8383
.await
8484
.expect("Failed to create table");
8585
}
86+
87+
88+
89+
#[macro_export]
90+
macro_rules! assert_err {
91+
($cond:expr,) => {
92+
$crate::assert_err!($cond);
93+
};
94+
($cond:expr) => {
95+
match $cond {
96+
Ok(t) => {
97+
panic!("assertion failed, expected Err(..), got Ok({:?})", t);
98+
},
99+
Err(e) => e,
100+
}
101+
};
102+
($cond:expr, $($arg:tt)+) => {
103+
match $cond {
104+
Ok(t) => {
105+
panic!("assertion failed, expected Err(..), got Ok({:?}): {}", t, format_args!($($arg)+));
106+
},
107+
Err(e) => e,
108+
}
109+
};
110+
}
111+
112+
#[macro_export]
113+
macro_rules! assert_none {
114+
($cond:expr,) => {
115+
$crate::assert_none!($cond);
116+
};
117+
($cond:expr) => {
118+
match $cond {
119+
Some(t) => {
120+
panic!("assertion failed, expected Err(..), got Ok({:?})", t);
121+
},
122+
None => (),
123+
}
124+
};
125+
($cond:expr, $($arg:tt)+) => {
126+
match $cond {
127+
Ok(t) => {
128+
panic!("assertion failed, expected None, got Some({:?}): {}", t, format_args!($($arg)+));
129+
},
130+
Err(e) => (),
131+
}
132+
};
133+
}

tests/round_trip_tests.rs

Lines changed: 84 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
use cipherstash_dynamodb::{Decryptable, Encryptable, EncryptedTable, Identifiable, Searchable};
1+
use cipherstash_client::ZeroKMSConfig;
2+
use cipherstash_dynamodb::{encrypted_table::Dynamo, Decryptable, Encryptable, EncryptedTable, Identifiable, Searchable};
23
use miette::IntoDiagnostic;
4+
use uuid::Uuid;
35
mod common;
46

57
#[derive(Debug, Clone, PartialEq, Identifiable, Encryptable, Decryptable, Searchable)]
@@ -123,26 +125,10 @@ struct Crazy {
123125
pt_k_none: Option<Vec<Vec<u8>>>,
124126
}
125127

126-
#[tokio::test]
127-
async fn test_round_trip() -> Result<(), Box<dyn std::error::Error>> {
128-
let config = aws_config::from_env()
129-
.endpoint_url("http://localhost:8000")
130-
.load()
131-
.await;
132-
133-
let client = aws_sdk_dynamodb::Client::new(&config);
134-
135-
let table_name = "crazy-record";
136-
137-
common::create_table(&client, table_name).await;
138-
139-
let table = EncryptedTable::init(client, table_name)
140-
.await
141-
.expect("Failed to init table");
142-
143-
let r = Crazy {
144-
email: "[email protected]".into(),
145-
name: "Dan".into(),
128+
fn build_test_record(email: &str, name: &str) -> Crazy {
129+
Crazy {
130+
email: email.into(),
131+
name: name.into(),
146132

147133
ct_a: 123,
148134
ct_b: 321,
@@ -205,17 +191,91 @@ async fn test_round_trip() -> Result<(), Box<dyn std::error::Error>> {
205191
pt_i_none: None,
206192
pt_j_none: None,
207193
pt_k_none: None,
208-
};
194+
}
195+
}
196+
197+
async fn init_table() -> EncryptedTable<Dynamo> {
198+
let config = aws_config::from_env()
199+
.endpoint_url("http://localhost:8000")
200+
.load()
201+
.await;
202+
203+
let client = aws_sdk_dynamodb::Client::new(&config);
204+
205+
let table_name = "crazy-record";
206+
207+
common::create_table(&client, table_name).await;
208+
209+
EncryptedTable::init(client, table_name)
210+
.await
211+
.expect("Failed to init table")
212+
}
209213

210-
table.put(r.clone()).await.into_diagnostic()?;
214+
#[tokio::test]
215+
async fn test_round_trip() -> Result<(), Box<dyn std::error::Error>> {
216+
let table = init_table().await;
217+
let record = build_test_record("[email protected]", "Dan");
218+
table.put(record.clone()).await.into_diagnostic()?;
211219

212220
let s: Crazy = table
213221
.get(("[email protected]", "Dan"))
214222
.await
215223
.into_diagnostic()?
216224
.unwrap();
217225

218-
assert_eq!(s, r);
226+
assert_eq!(s, record);
227+
228+
Ok(())
229+
}
230+
231+
#[tokio::test]
232+
async fn test_invalid_dataset() -> Result<(), Box<dyn std::error::Error>> {
233+
let table = init_table().await;
234+
let record = build_test_record("[email protected]", "Dan");
235+
236+
// A random UUID doesn't exist
237+
assert_err!(table.put_via(record.clone(), Uuid::new_v4()).await);
238+
239+
Ok(())
240+
}
241+
242+
#[tokio::test]
243+
async fn test_invalid_specific_dataset() -> miette::Result<()> {
244+
// TODO: Load client ID from env
245+
let client_id = Uuid::parse_str("b91e5b26-f21f-4694-8bce-c61c10e42301").into_diagnostic()?;
246+
let client = ZeroKMSConfig::builder()
247+
.with_env()
248+
.build()
249+
.into_diagnostic()?
250+
.create_client();
251+
252+
let dataset = client
253+
.create_dataset("test-dataset", "Test dataset")
254+
.await
255+
.into_diagnostic()?;
256+
257+
// Grant ourselves access to the dataset
258+
client.grant_dataset(client_id, dataset.id)
259+
.await
260+
.into_diagnostic()?;
261+
262+
let table = init_table().await;
263+
let record = build_test_record("[email protected]", "Person");
264+
265+
table.put_via(record.clone(), dataset.id).await?;
266+
267+
let s: Crazy = table
268+
.get_via(("[email protected]", "Person"), dataset.id)
269+
.await?
270+
.unwrap();
271+
272+
assert_eq!(s, record);
273+
274+
// Test that we can't get the record via the default dataset
275+
assert_none!(table
276+
.get::<Crazy>(("[email protected]", "Person"))
277+
.await?);
278+
219279

220280
Ok(())
221281
}

0 commit comments

Comments
 (0)