Skip to content

Commit 4d6d0d9

Browse files
authored
RUST-2161 Support auto encryption in unified tests (#1426)
1 parent e56ca78 commit 4d6d0d9

24 files changed

+2081
-92
lines changed

src/client/csfle.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,15 @@ pub(crate) fn aux_collections(
230230
}
231231
Ok(out)
232232
}
233+
234+
impl Client {
235+
pub(crate) async fn init_csfle(&self, opts: AutoEncryptionOptions) -> Result<()> {
236+
let mut csfle_state = self.inner.csfle.write().await;
237+
if csfle_state.is_some() {
238+
return Err(Error::internal("double initialization of csfle state"));
239+
}
240+
*csfle_state = Some(ClientState::new(self, opts).await?);
241+
242+
Ok(())
243+
}
244+
}

src/client/csfle/client_builder.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ impl EncryptedClientBuilder {
112112
/// mongocryptd as part of `Client` initialization.
113113
pub async fn build(self) -> Result<Client> {
114114
let client = Client::with_options(self.client_options)?;
115-
*client.inner.csfle.write().await =
116-
Some(super::ClientState::new(&client, self.enc_opts).await?);
115+
client.init_csfle(self.enc_opts).await?;
117116
Ok(client)
118117
}
119118
}

src/coll/action/drop.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ where
5252
}
5353
// * from a `list_collections` call:
5454
let found;
55-
if enc_fields.is_none() && client_enc_fields.is_some() {
55+
if enc_fields.is_none() && enc_opts.is_some() {
5656
let filter = doc! { "name": self.name() };
5757
let mut specs: Vec<_> = match session.as_deref_mut() {
5858
Some(s) => {

src/test/csfle.rs

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ pub(crate) type KmsProviderList = Vec<KmsInfo>;
4444
static CSFLE_LOCAL_KEY: Lazy<String> = Lazy::new(|| get_env_var("CSFLE_LOCAL_KEY"));
4545
static FLE_AWS_KEY: Lazy<String> = Lazy::new(|| get_env_var("FLE_AWS_KEY"));
4646
static FLE_AWS_SECRET: Lazy<String> = Lazy::new(|| get_env_var("FLE_AWS_SECRET"));
47+
static FLE_AWS_TEMP_KEY: Lazy<String> = Lazy::new(|| get_env_var("CSFLE_AWS_TEMP_ACCESS_KEY_ID"));
48+
static FLE_AWS_TEMP_SECRET: Lazy<String> =
49+
Lazy::new(|| get_env_var("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY"));
50+
static FLE_AWS_TEMP_SESSION_TOKEN: Lazy<String> =
51+
Lazy::new(|| get_env_var("CSFLE_AWS_TEMP_SESSION_TOKEN"));
4752
static FLE_AZURE_TENANTID: Lazy<String> = Lazy::new(|| get_env_var("FLE_AZURE_TENANTID"));
4853
static FLE_AZURE_CLIENTID: Lazy<String> = Lazy::new(|| get_env_var("FLE_AZURE_CLIENTID"));
4954
static FLE_AZURE_CLIENTSECRET: Lazy<String> = Lazy::new(|| get_env_var("FLE_AZURE_CLIENTSECRET"));
@@ -61,13 +66,16 @@ static CSFLE_TLS_CERT_DIR: Lazy<String> = Lazy::new(|| get_env_var("CSFLE_TLS_CE
6166
static CRYPT_SHARED_LIB_PATH: Lazy<String> = Lazy::new(|| get_env_var("CRYPT_SHARED_LIB_PATH"));
6267

6368
fn get_env_var(name: &str) -> String {
64-
std::env::var(name).unwrap_or_else(|_| {
65-
panic!(
66-
"Missing environment variable for {}. See src/test/csfle.rs for the list of required \
67-
variables and instructions for retrieving them.",
68-
name
69-
)
70-
})
69+
match std::env::var(name) {
70+
Ok(v) if !v.is_empty() => v,
71+
_ => {
72+
panic!(
73+
"Missing environment variable for {}. See src/test/csfle.rs for the list of \
74+
required variables and instructions for retrieving them.",
75+
name
76+
)
77+
}
78+
}
7179
}
7280

7381
pub(crate) static AWS_KMS: Lazy<KmsInfo> = Lazy::new(|| {
@@ -80,6 +88,17 @@ pub(crate) static AWS_KMS: Lazy<KmsInfo> = Lazy::new(|| {
8088
None,
8189
)
8290
});
91+
static AWS_TEMP_KMS: Lazy<KmsInfo> = Lazy::new(|| {
92+
(
93+
KmsProvider::aws(),
94+
doc! {
95+
"accessKeyId": &*FLE_AWS_TEMP_KEY,
96+
"secretAccessKey": &*FLE_AWS_TEMP_SECRET,
97+
"sessionToken": &*FLE_AWS_TEMP_SESSION_TOKEN,
98+
},
99+
None,
100+
)
101+
});
83102
pub(crate) static AWS_KMS_NAME1: Lazy<KmsInfo> = Lazy::new(|| {
84103
let aws_info = AWS_KMS.clone();
85104
(aws_info.0.with_name("name1"), aws_info.1, aws_info.2)
@@ -310,3 +329,39 @@ async fn fle2v2_ok(name: &str) -> bool {
310329
}
311330
true
312331
}
332+
333+
pub(crate) fn fill_kms_placeholders(
334+
kms_provider_map: std::collections::HashMap<mongocrypt::ctx::KmsProvider, Document>,
335+
) -> KmsProviderList {
336+
use mongocrypt::ctx::KmsProviderType;
337+
338+
let placeholder = doc! { "$$placeholder": 1 };
339+
340+
let mut kms_providers = Vec::new();
341+
for (provider, mut config) in kms_provider_map {
342+
// AWS uses temp creds if the "sessionToken" key is present in the config
343+
let test_kms_provider = if *provider.provider_type() == KmsProviderType::Aws
344+
&& config.contains_key("sessionToken")
345+
{
346+
Some(&*AWS_TEMP_KMS)
347+
} else {
348+
(*ALL_KMS_PROVIDERS).iter().find(|(p, ..)| p == &provider)
349+
};
350+
351+
for (key, value) in config.iter_mut() {
352+
if value.as_document() == Some(&placeholder) {
353+
let test_kms_provider = test_kms_provider
354+
.unwrap_or_else(|| panic!("missing config for {:?}", provider));
355+
let placeholder_value = test_kms_provider.1.get(key).unwrap_or_else(|| {
356+
panic!("provider config {:?} missing key {:?}", provider, key)
357+
});
358+
*value = placeholder_value.clone();
359+
}
360+
}
361+
362+
let tls_options = test_kms_provider.and_then(|(_, _, tls_options)| tls_options.clone());
363+
kms_providers.push((provider, config, tls_options));
364+
}
365+
366+
kms_providers
367+
}

src/test/csfle/spec.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,16 @@ async fn run_unified() {
1616

1717
#[tokio::test(flavor = "multi_thread")]
1818
async fn run_legacy() {
19-
// TODO RUST-528: unskip this file
20-
let mut skipped_files = vec!["timeoutMS.json"];
19+
let mut skipped_files = vec![
20+
// TODO RUST-528: unskip this file
21+
"timeoutMS.json",
22+
// These files have been migrated to unified tests.
23+
// TODO DRIVERS-3178 remove these once the files are gone.
24+
"fle2v2-BypassQueryAnalysis.json",
25+
"fle2v2-EncryptedFields-vs-EncryptedFieldsMap.json",
26+
"localSchema.json",
27+
"maxWireVersion.json",
28+
];
2129
if cfg!(not(feature = "openssl-tls")) {
2230
skipped_files.push("kmipKMS.json");
2331
}

0 commit comments

Comments
 (0)