diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 40bc6ead4..f3d79d460 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -6,7 +6,7 @@ set -o pipefail source .evergreen/env.sh source .evergreen/cargo-test.sh -FEATURE_FLAGS+=("tracing-unstable") +FEATURE_FLAGS+=("tracing-unstable" "cert-key-password") if [ "$ASYNC_STD" = true ]; then CARGO_OPTIONS+=("--no-default-features") diff --git a/Cargo.toml b/Cargo.toml index 7e90c93a2..c0824d178 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,6 +87,8 @@ in-use-encryption-unstable = ["mongocrypt", "rayon", "num_cpus"] # TODO: pending https://github.com/tokio-rs/tracing/issues/2036 stop depending directly on log. tracing-unstable = ["tracing", "log"] +cert-key-password = ["pem", "pkcs8"] + [dependencies] async-executor = { version = "=1.5.1", optional = true } # TODO RUST-1768: remove this async-trait = "0.1.42" @@ -110,7 +112,9 @@ mongocrypt = { version = "0.1.2", optional = true } num_cpus = { version = "1.13.1", optional = true } openssl = { version = "0.10.38", optional = true } openssl-probe = { version = "0.1.5", optional = true } +pem = { version = "3.0.4", optional = true } percent-encoding = "2.0.0" +pkcs8 = { version = "0.10.2", features = ["encryption", "pkcs5"], optional = true } rand = { version = "0.8.3", features = ["small_rng"] } rayon = { version = "1.5.3", optional = true } rustc_version_runtime = "0.2.1" diff --git a/src/client/options.rs b/src/client/options.rs index a7e37e5b4..2703b9a8e 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -1057,6 +1057,10 @@ pub struct TlsOptions { /// The default value is to error on invalid hostnames. #[cfg(feature = "openssl-tls")] pub allow_invalid_hostnames: Option, + + /// If set, the key in `cert_key_file_path` must be encrypted with this password. + #[cfg(feature = "cert-key-password")] + pub tls_certificate_key_file_password: Option>, } impl TlsOptions { @@ -1074,6 +1078,8 @@ impl TlsOptions { tlscafile: Option<&'a str>, tlscertificatekeyfile: Option<&'a str>, tlsallowinvalidcertificates: Option, + #[cfg(feature = "cert-key-password")] + tlscertificatekeyfilepassword: Option<&'a str>, } let state = TlsOptionsHelper { @@ -1087,6 +1093,11 @@ impl TlsOptions { .as_ref() .map(|s| s.to_str().unwrap()), tlsallowinvalidcertificates: tls_options.allow_invalid_certificates, + #[cfg(feature = "cert-key-password")] + tlscertificatekeyfilepassword: tls_options + .tls_certificate_key_file_password + .as_deref() + .map(|b| std::str::from_utf8(b).unwrap()), }; state.serialize(serializer) } @@ -2328,6 +2339,25 @@ impl ConnectionString { )) } }, + #[cfg(feature = "cert-key-password")] + "tlscertificatekeyfilepassword" => match &mut self.tls { + Some(Tls::Disabled) => { + return Err(ErrorKind::InvalidArgument { + message: "'tlsCertificateKeyFilePassword' can't be set if tls=false".into(), + } + .into()); + } + Some(Tls::Enabled(options)) => { + options.tls_certificate_key_file_password = Some(value.as_bytes().to_vec()); + } + None => { + self.tls = Some(Tls::Enabled( + TlsOptions::builder() + .tls_certificate_key_file_password(value.as_bytes().to_vec()) + .build(), + )) + } + }, "uuidrepresentation" => match value.to_lowercase().as_str() { "csharplegacy" => self.uuid_representation = Some(UuidRepresentation::CSharpLegacy), "javalegacy" => self.uuid_representation = Some(UuidRepresentation::JavaLegacy), diff --git a/src/client/options/test.rs b/src/client/options/test.rs index 992ae214a..c083fb231 100644 --- a/src/client/options/test.rs +++ b/src/client/options/test.rs @@ -59,7 +59,7 @@ async fn run_tests(path: &[&str], skipped_files: &[&str]) { "tlsInsecure is parsed correctly", // The driver does not support maxPoolSize=0 "maxPoolSize=0 does not error", - // TODO RUST-226: unskip this test + #[cfg(not(feature = "cert-key-password"))] "Valid tlsCertificateKeyFilePassword is parsed correctly", "SRV URI with custom srvServiceName", ]; diff --git a/src/runtime.rs b/src/runtime.rs index 85fb397e0..52e090ed5 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -4,6 +4,8 @@ mod http; #[cfg(feature = "async-std-runtime")] mod interval; mod join_handle; +#[cfg(feature = "cert-key-password")] +mod pem; #[cfg(any( feature = "in-use-encryption-unstable", all(test, not(feature = "sync"), not(feature = "tokio-sync")) diff --git a/src/runtime/pem.rs b/src/runtime/pem.rs new file mode 100644 index 000000000..ef3c3109f --- /dev/null +++ b/src/runtime/pem.rs @@ -0,0 +1,30 @@ +use crate::error::{ErrorKind, Result}; + +pub(crate) fn decrypt_private_key(pem_data: &[u8], password: &[u8]) -> Result> { + let pems = pem::parse_many(pem_data).map_err(|error| ErrorKind::InvalidTlsConfig { + message: format!("Could not parse pemfile: {}", error), + })?; + let mut iter = pems + .into_iter() + .filter(|pem| pem.tag() == "ENCRYPTED PRIVATE KEY"); + let encrypted_bytes = match iter.next() { + Some(pem) => pem.into_contents(), + None => { + return Err(ErrorKind::InvalidTlsConfig { + message: "No encrypted private keys found".into(), + } + .into()) + } + }; + let encrypted_key = pkcs8::EncryptedPrivateKeyInfo::try_from(encrypted_bytes.as_slice()) + .map_err(|error| ErrorKind::InvalidTlsConfig { + message: format!("Invalid encrypted private key: {}", error), + })?; + let decrypted_key = + encrypted_key + .decrypt(password) + .map_err(|error| ErrorKind::InvalidTlsConfig { + message: format!("Failed to decrypt private key: {}", error), + })?; + Ok(decrypted_key.as_bytes().to_vec()) +} diff --git a/src/runtime/tls_openssl.rs b/src/runtime/tls_openssl.rs index 761a047bb..cca4816c1 100644 --- a/src/runtime/tls_openssl.rs +++ b/src/runtime/tls_openssl.rs @@ -40,11 +40,7 @@ impl TlsConfig { None => true, }; - let connector = make_openssl_connector(options).map_err(|e| { - Error::from(ErrorKind::InvalidTlsConfig { - message: e.to_string(), - }) - })?; + let connector = make_openssl_connector(options)?; Ok(TlsConfig { connector, @@ -105,25 +101,50 @@ impl AsyncWrite for AsyncTlsStream { } } -fn make_openssl_connector(cfg: TlsOptions) -> std::result::Result { - let mut builder = SslConnector::builder(SslMethod::tls_client())?; +fn make_openssl_connector(cfg: TlsOptions) -> Result { + let openssl_err = |e: ErrorStack| { + Error::from(ErrorKind::InvalidTlsConfig { + message: e.to_string(), + }) + }; + + let mut builder = SslConnector::builder(SslMethod::tls_client()).map_err(openssl_err)?; let TlsOptions { allow_invalid_certificates, ca_file_path, cert_key_file_path, allow_invalid_hostnames: _, + #[cfg(feature = "cert-key-password")] + tls_certificate_key_file_password, } = cfg; if let Some(true) = allow_invalid_certificates { builder.set_verify(SslVerifyMode::NONE); } if let Some(path) = ca_file_path { - builder.set_ca_file(path)?; + builder.set_ca_file(path).map_err(openssl_err)?; } if let Some(path) = cert_key_file_path { - builder.set_certificate_file(path.clone(), SslFiletype::PEM)?; - builder.set_private_key_file(path, SslFiletype::PEM)?; + builder + .set_certificate_file(path.clone(), SslFiletype::PEM) + .map_err(openssl_err)?; + // Inner fn so the cert-key-password path can early return + let handle_private_key = || -> Result<()> { + #[cfg(feature = "cert-key-password")] + if let Some(key_pw) = tls_certificate_key_file_password { + let contents = std::fs::read(&path)?; + let key_bytes = super::pem::decrypt_private_key(&contents, &key_pw)?; + let key = + openssl::pkey::PKey::private_key_from_der(&key_bytes).map_err(openssl_err)?; + builder.set_private_key(&key).map_err(openssl_err)?; + return Ok(()); + } + builder + .set_private_key_file(path, SslFiletype::PEM) + .map_err(openssl_err) + }; + handle_private_key()?; } Ok(builder.build()) diff --git a/src/runtime/tls_rustls.rs b/src/runtime/tls_rustls.rs index 2f66b3c25..e2fa4be0a 100644 --- a/src/runtime/tls_rustls.rs +++ b/src/runtime/tls_rustls.rs @@ -141,6 +141,13 @@ fn make_rustls_config(cfg: TlsOptions) -> Result { file.rewind()?; let key = loop { + #[cfg(feature = "cert-key-password")] + if let Some(key_pw) = cfg.tls_certificate_key_file_password.as_deref() { + use std::io::Read; + let mut contents = vec![]; + file.read_to_end(&mut contents)?; + break rustls::PrivateKey(super::pem::decrypt_private_key(&contents, key_pw)?); + } match read_one(&mut file) { Ok(Some(Item::PKCS8Key(bytes))) | Ok(Some(Item::RSAKey(bytes))) => { break rustls::PrivateKey(bytes)