Skip to content

Commit 732dc54

Browse files
authored
RUST-226 Support tlsCertificateKeyFilePassword (#1256) (#1257)
1 parent 1704eb6 commit 732dc54

File tree

8 files changed

+106
-12
lines changed

8 files changed

+106
-12
lines changed

.evergreen/run-tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ set -o pipefail
66
source .evergreen/env.sh
77
source .evergreen/cargo-test.sh
88

9-
FEATURE_FLAGS+=("tracing-unstable")
9+
FEATURE_FLAGS+=("tracing-unstable" "cert-key-password")
1010

1111
if [ "$ASYNC_STD" = true ]; then
1212
CARGO_OPTIONS+=("--no-default-features")

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ in-use-encryption-unstable = ["mongocrypt", "rayon", "num_cpus"]
8787
# TODO: pending https://github.com/tokio-rs/tracing/issues/2036 stop depending directly on log.
8888
tracing-unstable = ["tracing", "log"]
8989

90+
cert-key-password = ["pem", "pkcs8"]
91+
9092
[dependencies]
9193
async-executor = { version = "=1.5.1", optional = true } # TODO RUST-1768: remove this
9294
async-trait = "0.1.42"
@@ -110,7 +112,9 @@ mongocrypt = { version = "0.1.2", optional = true }
110112
num_cpus = { version = "1.13.1", optional = true }
111113
openssl = { version = "0.10.38", optional = true }
112114
openssl-probe = { version = "0.1.5", optional = true }
115+
pem = { version = "3.0.4", optional = true }
113116
percent-encoding = "2.0.0"
117+
pkcs8 = { version = "0.10.2", features = ["encryption", "pkcs5"], optional = true }
114118
rand = { version = "0.8.3", features = ["small_rng"] }
115119
rayon = { version = "1.5.3", optional = true }
116120
rustc_version_runtime = "0.2.1"

src/client/options.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,10 @@ pub struct TlsOptions {
10571057
/// The default value is to error on invalid hostnames.
10581058
#[cfg(feature = "openssl-tls")]
10591059
pub allow_invalid_hostnames: Option<bool>,
1060+
1061+
/// If set, the key in `cert_key_file_path` must be encrypted with this password.
1062+
#[cfg(feature = "cert-key-password")]
1063+
pub tls_certificate_key_file_password: Option<Vec<u8>>,
10601064
}
10611065

10621066
impl TlsOptions {
@@ -1074,6 +1078,8 @@ impl TlsOptions {
10741078
tlscafile: Option<&'a str>,
10751079
tlscertificatekeyfile: Option<&'a str>,
10761080
tlsallowinvalidcertificates: Option<bool>,
1081+
#[cfg(feature = "cert-key-password")]
1082+
tlscertificatekeyfilepassword: Option<&'a str>,
10771083
}
10781084

10791085
let state = TlsOptionsHelper {
@@ -1087,6 +1093,11 @@ impl TlsOptions {
10871093
.as_ref()
10881094
.map(|s| s.to_str().unwrap()),
10891095
tlsallowinvalidcertificates: tls_options.allow_invalid_certificates,
1096+
#[cfg(feature = "cert-key-password")]
1097+
tlscertificatekeyfilepassword: tls_options
1098+
.tls_certificate_key_file_password
1099+
.as_deref()
1100+
.map(|b| std::str::from_utf8(b).unwrap()),
10901101
};
10911102
state.serialize(serializer)
10921103
}
@@ -2328,6 +2339,25 @@ impl ConnectionString {
23282339
))
23292340
}
23302341
},
2342+
#[cfg(feature = "cert-key-password")]
2343+
"tlscertificatekeyfilepassword" => match &mut self.tls {
2344+
Some(Tls::Disabled) => {
2345+
return Err(ErrorKind::InvalidArgument {
2346+
message: "'tlsCertificateKeyFilePassword' can't be set if tls=false".into(),
2347+
}
2348+
.into());
2349+
}
2350+
Some(Tls::Enabled(options)) => {
2351+
options.tls_certificate_key_file_password = Some(value.as_bytes().to_vec());
2352+
}
2353+
None => {
2354+
self.tls = Some(Tls::Enabled(
2355+
TlsOptions::builder()
2356+
.tls_certificate_key_file_password(value.as_bytes().to_vec())
2357+
.build(),
2358+
))
2359+
}
2360+
},
23312361
"uuidrepresentation" => match value.to_lowercase().as_str() {
23322362
"csharplegacy" => self.uuid_representation = Some(UuidRepresentation::CSharpLegacy),
23332363
"javalegacy" => self.uuid_representation = Some(UuidRepresentation::JavaLegacy),

src/client/options/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async fn run_tests(path: &[&str], skipped_files: &[&str]) {
5959
"tlsInsecure is parsed correctly",
6060
// The driver does not support maxPoolSize=0
6161
"maxPoolSize=0 does not error",
62-
// TODO RUST-226: unskip this test
62+
#[cfg(not(feature = "cert-key-password"))]
6363
"Valid tlsCertificateKeyFilePassword is parsed correctly",
6464
"SRV URI with custom srvServiceName",
6565
];

src/runtime.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ mod http;
44
#[cfg(feature = "async-std-runtime")]
55
mod interval;
66
mod join_handle;
7+
#[cfg(feature = "cert-key-password")]
8+
mod pem;
79
#[cfg(any(
810
feature = "in-use-encryption-unstable",
911
all(test, not(feature = "sync"), not(feature = "tokio-sync"))

src/runtime/pem.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use crate::error::{ErrorKind, Result};
2+
3+
pub(crate) fn decrypt_private_key(pem_data: &[u8], password: &[u8]) -> Result<Vec<u8>> {
4+
let pems = pem::parse_many(pem_data).map_err(|error| ErrorKind::InvalidTlsConfig {
5+
message: format!("Could not parse pemfile: {}", error),
6+
})?;
7+
let mut iter = pems
8+
.into_iter()
9+
.filter(|pem| pem.tag() == "ENCRYPTED PRIVATE KEY");
10+
let encrypted_bytes = match iter.next() {
11+
Some(pem) => pem.into_contents(),
12+
None => {
13+
return Err(ErrorKind::InvalidTlsConfig {
14+
message: "No encrypted private keys found".into(),
15+
}
16+
.into())
17+
}
18+
};
19+
let encrypted_key = pkcs8::EncryptedPrivateKeyInfo::try_from(encrypted_bytes.as_slice())
20+
.map_err(|error| ErrorKind::InvalidTlsConfig {
21+
message: format!("Invalid encrypted private key: {}", error),
22+
})?;
23+
let decrypted_key =
24+
encrypted_key
25+
.decrypt(password)
26+
.map_err(|error| ErrorKind::InvalidTlsConfig {
27+
message: format!("Failed to decrypt private key: {}", error),
28+
})?;
29+
Ok(decrypted_key.as_bytes().to_vec())
30+
}

src/runtime/tls_openssl.rs

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,7 @@ impl TlsConfig {
4040
None => true,
4141
};
4242

43-
let connector = make_openssl_connector(options).map_err(|e| {
44-
Error::from(ErrorKind::InvalidTlsConfig {
45-
message: e.to_string(),
46-
})
47-
})?;
43+
let connector = make_openssl_connector(options)?;
4844

4945
Ok(TlsConfig {
5046
connector,
@@ -105,25 +101,50 @@ impl AsyncWrite for AsyncTlsStream {
105101
}
106102
}
107103

108-
fn make_openssl_connector(cfg: TlsOptions) -> std::result::Result<SslConnector, ErrorStack> {
109-
let mut builder = SslConnector::builder(SslMethod::tls_client())?;
104+
fn make_openssl_connector(cfg: TlsOptions) -> Result<SslConnector> {
105+
let openssl_err = |e: ErrorStack| {
106+
Error::from(ErrorKind::InvalidTlsConfig {
107+
message: e.to_string(),
108+
})
109+
};
110+
111+
let mut builder = SslConnector::builder(SslMethod::tls_client()).map_err(openssl_err)?;
110112

111113
let TlsOptions {
112114
allow_invalid_certificates,
113115
ca_file_path,
114116
cert_key_file_path,
115117
allow_invalid_hostnames: _,
118+
#[cfg(feature = "cert-key-password")]
119+
tls_certificate_key_file_password,
116120
} = cfg;
117121

118122
if let Some(true) = allow_invalid_certificates {
119123
builder.set_verify(SslVerifyMode::NONE);
120124
}
121125
if let Some(path) = ca_file_path {
122-
builder.set_ca_file(path)?;
126+
builder.set_ca_file(path).map_err(openssl_err)?;
123127
}
124128
if let Some(path) = cert_key_file_path {
125-
builder.set_certificate_file(path.clone(), SslFiletype::PEM)?;
126-
builder.set_private_key_file(path, SslFiletype::PEM)?;
129+
builder
130+
.set_certificate_file(path.clone(), SslFiletype::PEM)
131+
.map_err(openssl_err)?;
132+
// Inner fn so the cert-key-password path can early return
133+
let handle_private_key = || -> Result<()> {
134+
#[cfg(feature = "cert-key-password")]
135+
if let Some(key_pw) = tls_certificate_key_file_password {
136+
let contents = std::fs::read(&path)?;
137+
let key_bytes = super::pem::decrypt_private_key(&contents, &key_pw)?;
138+
let key =
139+
openssl::pkey::PKey::private_key_from_der(&key_bytes).map_err(openssl_err)?;
140+
builder.set_private_key(&key).map_err(openssl_err)?;
141+
return Ok(());
142+
}
143+
builder
144+
.set_private_key_file(path, SslFiletype::PEM)
145+
.map_err(openssl_err)
146+
};
147+
handle_private_key()?;
127148
}
128149

129150
Ok(builder.build())

src/runtime/tls_rustls.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ fn make_rustls_config(cfg: TlsOptions) -> Result<rustls::ClientConfig> {
141141

142142
file.rewind()?;
143143
let key = loop {
144+
#[cfg(feature = "cert-key-password")]
145+
if let Some(key_pw) = cfg.tls_certificate_key_file_password.as_deref() {
146+
use std::io::Read;
147+
let mut contents = vec![];
148+
file.read_to_end(&mut contents)?;
149+
break rustls::PrivateKey(super::pem::decrypt_private_key(&contents, key_pw)?);
150+
}
144151
match read_one(&mut file) {
145152
Ok(Some(Item::PKCS8Key(bytes))) | Ok(Some(Item::RSAKey(bytes))) => {
146153
break rustls::PrivateKey(bytes)

0 commit comments

Comments
 (0)