Skip to content

Commit 324275a

Browse files
committed
Better handling of table setup for async tests
1 parent 45bf2e1 commit 324275a

File tree

5 files changed

+171
-91
lines changed

5 files changed

+171
-91
lines changed

src/encrypted_table/mod.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -397,17 +397,15 @@ impl EncryptedTable<Dynamo> {
397397
})
398398
}
399399

400+
/// Get a record from the table by primary key from the default dataset.
400401
pub async fn get<T>(&self, k: impl Into<T::PrimaryKey>) -> Result<Option<T>, GetError>
401402
where
402403
T: Decryptable + Identifiable,
403404
{
404-
// TODO: Don't unwrap
405-
let scoped_cipher = ScopedZeroKmsCipher::init(self.cipher.clone(), None)
406-
.await
407-
.unwrap();
408-
self.get_inner(k, scoped_cipher).await
405+
self.get_inner(k, None).await
409406
}
410407

408+
/// Get a record from the table by primary key from a specific dataset.
411409
pub async fn get_via<T>(
412410
&self,
413411
k: impl Into<T::PrimaryKey>,
@@ -416,21 +414,19 @@ impl EncryptedTable<Dynamo> {
416414
where
417415
T: Decryptable + Identifiable,
418416
{
419-
// TODO: Don't unwrap
420-
let scoped_cipher = ScopedZeroKmsCipher::init(self.cipher.clone(), Some(dataset_id))
421-
.await
422-
.unwrap();
423-
self.get_inner(k, scoped_cipher).await
417+
self.get_inner(k, Some(dataset_id)).await
424418
}
425419

426420
async fn get_inner<T>(
427421
&self,
428422
k: impl Into<T::PrimaryKey>,
429-
cipher: ScopedZeroKmsCipher,
423+
dataset_id: Option<Uuid>,
430424
) -> Result<Option<T>, GetError>
431425
where
432426
T: Decryptable + Identifiable,
433427
{
428+
let cipher = ScopedZeroKmsCipher::init(self.cipher.clone(), dataset_id).await?;
429+
434430
let PrimaryKeyParts { pk, sk } =
435431
encrypt_primary_key_parts(&cipher, PreparedPrimaryKey::new::<T>(k))?;
436432

@@ -451,13 +447,15 @@ impl EncryptedTable<Dynamo> {
451447
}
452448
}
453449

450+
/// Delete a record from the table by primary key from the default dataset.
454451
pub async fn delete<E: Searchable + Identifiable>(
455452
&self,
456453
k: impl Into<E::PrimaryKey>,
457454
) -> Result<(), DeleteError> {
458455
self.delete_inner::<E>(k.into(), None).await
459456
}
460457

458+
/// Delete a record from the table by primary key from a specific dataset.
461459
pub async fn delete_via<E: Searchable + Identifiable>(
462460
&self,
463461
k: impl Into<E::PrimaryKey>,
@@ -489,13 +487,15 @@ impl EncryptedTable<Dynamo> {
489487
Ok(())
490488
}
491489

490+
/// Put a record into the table using the default dataset.
492491
pub async fn put<T>(&self, record: T) -> Result<(), PutError>
493492
where
494493
T: Searchable + Identifiable,
495494
{
496495
self.put_inner(record, None).await
497496
}
498497

498+
/// Put a record into the table using a specific dataset.
499499
pub async fn put_via<T>(&self, record: T, dataset_id: Uuid) -> Result<(), PutError>
500500
where
501501
T: Searchable + Identifiable,

src/encrypted_table/query.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use cipherstash_client::encryption::{
55
};
66
use itertools::Itertools;
77
use std::{borrow::Cow, collections::HashMap, marker::PhantomData};
8+
use uuid::Uuid;
89

910
use crate::{
1011
traits::{Decryptable, Searchable},
@@ -132,14 +133,32 @@ impl<S> QueryBuilder<S, &EncryptedTable<Dynamo>>
132133
where
133134
S: Searchable + Identifiable,
134135
{
135-
// TODO: Add load_via
136+
/// Load all records of type `T` matching the query.
137+
/// The default dataset is used.
138+
///
139+
/// While a client can decrypt records from any dataset it has access to,
140+
/// queries are always scoped to a single dataset.
136141
pub async fn load<T>(self) -> Result<Vec<T>, QueryError>
137142
where
138143
T: Decryptable + Identifiable,
139144
{
140-
let scoped_cipher = ScopedZeroKmsCipher::init(self.storage.cipher.clone(), None)
141-
.await
142-
.unwrap();
145+
self.load_inner(None).await
146+
}
147+
148+
/// Similar to `load`, but the query is scoped to a specific dataset.
149+
pub async fn load_via<T>(self, dataset_id: Uuid) -> Result<Vec<T>, QueryError>
150+
where
151+
T: Decryptable + Identifiable,
152+
{
153+
self.load_inner(Some(dataset_id)).await
154+
}
155+
156+
async fn load_inner<T>(self, dataset_id: Option<Uuid>) -> Result<Vec<T>, QueryError>
157+
where
158+
T: Decryptable + Identifiable,
159+
{
160+
let scoped_cipher =
161+
ScopedZeroKmsCipher::init(self.storage.cipher.clone(), dataset_id).await?;
143162

144163
let storage = self.storage;
145164
let query = self.build()?;

src/errors/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ pub enum GetError {
4848
Encryption(#[from] EncryptionError),
4949
#[error("AwsError: {0}")]
5050
Aws(String),
51+
52+
#[error("ZeroKMS Error: {0}")]
53+
ZeroKMS(#[from] zerokms::Error),
5154
}
5255

5356
/// Error returned by `EncryptedTable::delete` when indexing and deleting records in DynamoDB
@@ -100,6 +103,9 @@ pub enum QueryError {
100103

101104
#[error(transparent)]
102105
DynamoError(#[from] SdkError<operation::query::QueryError>),
106+
107+
#[error("ZeroKMS Error: {0}")]
108+
ZeroKMS(#[from] zerokms::Error),
103109
}
104110

105111
pub trait DynamoError: std::error::Error + Sized {}

tests/common.rs

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,91 @@ use aws_sdk_dynamodb::{
55
},
66
Client,
77
};
8+
use cipherstash_dynamodb::EncryptedTable;
9+
use miette::Diagnostic;
10+
use std::{env, future::Future, sync::OnceLock};
11+
use uuid::Uuid;
812

9-
pub async fn create_table(client: &Client, table_name: &str) {
13+
static SECONDARY_DATASET_ID: OnceLock<Uuid> = OnceLock::new();
14+
15+
#[derive(Debug, thiserror::Error, Diagnostic)]
16+
#[error("Check failed: {0}")]
17+
pub struct CheckFailed(String);
18+
19+
#[allow(dead_code)]
20+
pub fn check_eq<A, B>(a: A, b: B) -> miette::Result<()>
21+
where
22+
A: std::fmt::Debug + PartialEq<B>,
23+
B: std::fmt::Debug,
24+
{
25+
if a == b {
26+
Ok(())
27+
} else {
28+
Err(CheckFailed(format!("Expected {:?} to equal {:?}", a, b)).into())
29+
}
30+
}
31+
32+
#[allow(dead_code)]
33+
pub fn check_err<R, E>(result: Result<R, E>) -> miette::Result<()>
34+
where
35+
E: std::fmt::Debug,
36+
R: std::fmt::Debug,
37+
{
38+
if result.is_err() {
39+
Ok(())
40+
} else {
41+
Err(CheckFailed(format!("Expected error, got {:?}", result)).into())
42+
}
43+
}
44+
45+
#[allow(dead_code)]
46+
pub fn check_none<R>(result: Option<R>) -> miette::Result<()>
47+
where
48+
R: std::fmt::Debug,
49+
{
50+
if result.is_none() {
51+
Ok(())
52+
} else {
53+
Err(CheckFailed(format!("Expected None, got {:?}", result)).into())
54+
}
55+
}
56+
57+
#[allow(dead_code)]
58+
pub fn fail_not_found() -> CheckFailed {
59+
CheckFailed("Record not found".into())
60+
}
61+
62+
/// Run a test with an encrypted table.
63+
/// The table will be created before the test and deleted after the test.
64+
/// The name is used as a prefix in case its helpful to distinguish between tests.
65+
/// A random is appended to the name to ensure uniqueness for async tests.
66+
#[allow(dead_code)]
67+
pub async fn with_encrypted_table<F: Future<Output = miette::Result<()>>>(
68+
table_name: &str,
69+
mut f: impl FnMut(EncryptedTable) -> F,
70+
) -> Result<(), Box<dyn std::error::Error>> {
71+
let config = aws_config::from_env()
72+
.endpoint_url("http://localhost:8000")
73+
.load()
74+
.await;
75+
76+
let table_name = format!("{}-{}", table_name, Uuid::new_v4());
77+
let client = aws_sdk_dynamodb::Client::new(&config);
78+
79+
create_table(&client, &table_name).await;
80+
let table = EncryptedTable::init(client.clone(), &table_name).await?;
81+
let result = f(table).await;
82+
83+
delete_table(&client, &table_name).await;
84+
Ok(result?)
85+
}
86+
87+
pub async fn delete_table(client: &Client, table_name: &str) {
1088
let _ = client.delete_table().table_name(table_name).send().await;
89+
}
90+
91+
pub async fn create_table(client: &Client, table_name: &str) {
92+
delete_table(client, table_name).await;
1193

1294
client
1395
.create_table()
@@ -84,6 +166,16 @@ pub async fn create_table(client: &Client, table_name: &str) {
84166
.expect("Failed to create table");
85167
}
86168

169+
#[allow(dead_code)]
170+
pub fn secondary_dataset_id() -> Uuid {
171+
*SECONDARY_DATASET_ID.get_or_init(|| {
172+
env::var("TEST_SECOND_DATASET_ID")
173+
.expect("TEST_SECOND_DATASET_ID must be set")
174+
.parse()
175+
.expect("TEST_SECOND_DATASET_ID must be a valid UUID")
176+
})
177+
}
178+
87179
#[macro_export]
88180
macro_rules! assert_err {
89181
($cond:expr,) => {

tests/round_trip_tests.rs

Lines changed: 38 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
use cipherstash_client::ZeroKMSConfig;
21
use cipherstash_dynamodb::{
3-
encrypted_table::Dynamo, Decryptable, Encryptable, EncryptedTable, Identifiable, Searchable,
2+
Decryptable, Encryptable, Identifiable, Searchable,
43
};
5-
use miette::IntoDiagnostic;
4+
use common::{check_eq, check_err, check_none, fail_not_found, secondary_dataset_id, with_encrypted_table};
65
use uuid::Uuid;
76
mod common;
87

@@ -196,86 +195,50 @@ fn build_test_record(email: &str, name: &str) -> Crazy {
196195
}
197196
}
198197

199-
async fn init_table() -> EncryptedTable<Dynamo> {
200-
let config = aws_config::from_env()
201-
.endpoint_url("http://localhost:8000")
202-
.load()
203-
.await;
204-
205-
let client = aws_sdk_dynamodb::Client::new(&config);
206-
207-
let table_name = "crazy-record";
208-
209-
common::create_table(&client, table_name).await;
210-
211-
EncryptedTable::init(client, table_name)
212-
.await
213-
.expect("Failed to init table")
214-
}
215-
216198
#[tokio::test]
217199
async fn test_round_trip() -> Result<(), Box<dyn std::error::Error>> {
218-
let table = init_table().await;
219-
let record = build_test_record("[email protected]", "Dan");
220-
table.put(record.clone()).await.into_diagnostic()?;
221-
222-
let s: Crazy = table
223-
.get(("[email protected]", "Dan"))
224-
.await
225-
.into_diagnostic()?
226-
.unwrap();
227-
228-
assert_eq!(s, record);
229-
230-
Ok(())
200+
with_encrypted_table("round-trip", |table| async move {
201+
let record = build_test_record("[email protected]", "Dan");
202+
table.put(record.clone()).await?;
203+
204+
let s: Crazy = table
205+
.get(("[email protected]", "Dan"))
206+
.await?
207+
.ok_or(fail_not_found())?;
208+
209+
check_eq(s, record)
210+
})
211+
.await
231212
}
232213

233214
#[tokio::test]
234215
async fn test_invalid_dataset() -> Result<(), Box<dyn std::error::Error>> {
235-
let table = init_table().await;
236-
let record = build_test_record("[email protected]", "Dan");
216+
with_encrypted_table("round-trip", |table| async move {
217+
let record = build_test_record("[email protected]", "Dan");
237218

238-
// A random UUID doesn't exist
239-
assert_err!(table.put_via(record.clone(), Uuid::new_v4()).await);
240-
241-
Ok(())
219+
// A random UUID doesn't exist
220+
check_err(table.put_via(record.clone(), Uuid::new_v4()).await)
221+
})
222+
.await
242223
}
243224

244225
#[tokio::test]
245-
async fn test_invalid_specific_dataset() -> miette::Result<()> {
246-
// TODO: Load client ID from env
247-
let client_id = Uuid::parse_str("b91e5b26-f21f-4694-8bce-c61c10e42301").into_diagnostic()?;
248-
let client = ZeroKMSConfig::builder()
249-
.with_env()
250-
.build()
251-
.into_diagnostic()?
252-
.create_client();
253-
254-
let dataset = client
255-
.create_dataset("test-dataset", "Test dataset")
256-
.await
257-
.into_diagnostic()?;
258-
259-
// Grant ourselves access to the dataset
260-
client
261-
.grant_dataset(client_id, dataset.id)
262-
.await
263-
.into_diagnostic()?;
264-
265-
let table = init_table().await;
266-
let record = build_test_record("[email protected]", "Person");
267-
268-
table.put_via(record.clone(), dataset.id).await?;
269-
270-
let s: Crazy = table
271-
.get_via(("[email protected]", "Person"), dataset.id)
272-
.await?
273-
.unwrap();
274-
275-
assert_eq!(s, record);
276-
277-
// Test that we can't get the record via the default dataset
278-
assert_none!(table.get::<Crazy>(("[email protected]", "Person")).await?);
279-
280-
Ok(())
226+
async fn test_invalid_specific_dataset() -> Result<(), Box<dyn std::error::Error>> {
227+
with_encrypted_table("round-trip", |table| async move {
228+
let record = build_test_record("[email protected]", "Person");
229+
table
230+
.put_via(record.clone(), secondary_dataset_id())
231+
.await?;
232+
233+
let s: Crazy = table
234+
.get_via(("[email protected]", "Person"), secondary_dataset_id())
235+
.await?
236+
.ok_or(fail_not_found())?;
237+
238+
check_eq(s, record)?;
239+
240+
// Test that we can't get the record via the default dataset
241+
check_none(table.get::<Crazy>(("[email protected]", "Person")).await?)
242+
})
243+
.await
281244
}

0 commit comments

Comments
 (0)