Skip to content

Commit 31dc8df

Browse files
feat: Allow users to specify a prebuilt 'rustls' configuration for TLS
1 parent 064d649 commit 31dc8df

File tree

9 files changed

+320
-152
lines changed

9 files changed

+320
-152
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlx-core/src/net/tls/mod.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,25 @@ impl std::fmt::Display for CertificateInput {
5757
}
5858
}
5959

60-
pub struct TlsConfig<'a> {
60+
#[derive(Debug, Clone)]
61+
#[non_exhaustive]
62+
pub enum TlsConfig<'a> {
63+
RawTlsConfig(RawTlsConfig<'a>),
64+
#[cfg(feature = "_tls-rustls")]
65+
PrebuiltRustls {
66+
config: &'a rustls::ClientConfig,
67+
hostname: &'a str,
68+
},
69+
}
70+
71+
#[derive(Debug, Clone)]
72+
pub struct RawTlsConfig<'a> {
6173
pub accept_invalid_certs: bool,
6274
pub accept_invalid_hostnames: bool,
6375
pub hostname: &'a str,
64-
pub root_cert_path: Option<&'a CertificateInput>,
65-
pub client_cert_path: Option<&'a CertificateInput>,
66-
pub client_key_path: Option<&'a CertificateInput>,
76+
pub root_cert: Option<&'a CertificateInput>,
77+
pub client_cert: Option<&'a CertificateInput>,
78+
pub client_key: Option<&'a CertificateInput>,
6779
}
6880

6981
pub async fn handshake<S, Ws>(

sqlx-core/src/net/tls/tls_native_tls.rs

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::io::{self, Read, Write};
22

33
use crate::io::ReadBuf;
44
use crate::net::tls::util::StdSocket;
5+
use crate::net::tls::RawTlsConfig;
56
use crate::net::tls::TlsConfig;
67
use crate::net::Socket;
78
use crate::rt;
@@ -39,36 +40,56 @@ impl<S: Socket> Socket for NativeTlsSocket<S> {
3940
}
4041
}
4142

42-
pub async fn handshake<S: Socket>(
43-
socket: S,
44-
config: TlsConfig<'_>,
45-
) -> crate::Result<NativeTlsSocket<S>> {
46-
let mut builder = native_tls::TlsConnector::builder();
47-
48-
builder
49-
.danger_accept_invalid_certs(config.accept_invalid_certs)
50-
.danger_accept_invalid_hostnames(config.accept_invalid_hostnames);
43+
impl TlsConfig<'_> {
44+
async fn native_tls_connector(&self) -> crate::Result<(native_tls::TlsConnector, &str), Error> {
45+
#[allow(irrefutable_let_patterns)]
46+
let TlsConfig::RawTlsConfig(RawTlsConfig {
47+
root_cert,
48+
client_cert,
49+
client_key,
50+
accept_invalid_certs,
51+
accept_invalid_hostnames,
52+
hostname,
53+
}) = self
54+
else {
55+
unreachable!()
56+
};
57+
let mut builder = native_tls::TlsConnector::builder();
58+
59+
builder
60+
.danger_accept_invalid_certs(*accept_invalid_certs)
61+
.danger_accept_invalid_hostnames(*accept_invalid_hostnames);
62+
63+
if let Some(root_cert) = root_cert {
64+
let data = root_cert.data().await?;
65+
builder.add_root_certificate(
66+
native_tls::Certificate::from_pem(&data).map_err(Error::tls)?,
67+
);
68+
}
5169

52-
if let Some(root_cert_path) = config.root_cert_path {
53-
let data = root_cert_path.data().await?;
54-
builder.add_root_certificate(native_tls::Certificate::from_pem(&data).map_err(Error::tls)?);
55-
}
70+
// authentication using user's key-file and its associated certificate
71+
if let (Some(cert), Some(key)) = (client_cert, client_key) {
72+
let cert = cert.data().await?;
73+
let key = key.data().await?;
74+
let identity = Identity::from_pkcs8(&cert, &key).map_err(Error::tls)?;
75+
builder.identity(identity);
76+
}
5677

57-
// authentication using user's key-file and its associated certificate
58-
if let (Some(cert_path), Some(key_path)) = (config.client_cert_path, config.client_key_path) {
59-
let cert_path = cert_path.data().await?;
60-
let key_path = key_path.data().await?;
61-
let identity = Identity::from_pkcs8(&cert_path, &key_path).map_err(Error::tls)?;
62-
builder.identity(identity);
78+
// The openssl TlsConnector synchronously loads certificates from files.
79+
// Loading these files can block for tens of milliseconds.
80+
let connector = rt::spawn_blocking(move || builder.build())
81+
.await
82+
.map_err(Error::tls)?;
83+
Ok((connector, hostname))
6384
}
85+
}
6486

65-
// The openssl TlsConnector synchronously loads certificates from files.
66-
// Loading these files can block for tens of milliseconds.
67-
let connector = rt::spawn_blocking(move || builder.build())
68-
.await
69-
.map_err(Error::tls)?;
70-
71-
let mut mid_handshake = match connector.connect(config.hostname, StdSocket::new(socket)) {
87+
pub async fn handshake<S: Socket>(
88+
socket: S,
89+
config: TlsConfig<'_>,
90+
) -> crate::Result<NativeTlsSocket<S>> {
91+
let (connector, hostname) = config.native_tls_connector().await?;
92+
let mut mid_handshake = match connector.connect(hostname, StdSocket::new(socket)) {
7293
Ok(tls_stream) => return Ok(NativeTlsSocket { stream: tls_stream }),
7394
Err(HandshakeError::Failure(e)) => return Err(Error::tls(e)),
7495
Err(HandshakeError::WouldBlock(mid_handshake)) => mid_handshake,

sqlx-core/src/net/tls/tls_rustls.rs

Lines changed: 113 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use rustls::{
1919
use crate::error::Error;
2020
use crate::io::ReadBuf;
2121
use crate::net::tls::util::StdSocket;
22-
use crate::net::tls::TlsConfig;
22+
use crate::net::tls::{RawTlsConfig, TlsConfig};
2323
use crate::net::Socket;
2424

2525
pub struct RustlsSocket<S: Socket> {
@@ -87,100 +87,136 @@ impl<S: Socket> Socket for RustlsSocket<S> {
8787
}
8888
}
8989

90-
pub async fn handshake<S>(socket: S, tls_config: TlsConfig<'_>) -> Result<RustlsSocket<S>, Error>
91-
where
92-
S: Socket,
93-
{
94-
#[cfg(all(
95-
feature = "_tls-rustls-aws-lc-rs",
96-
not(feature = "_tls-rustls-ring-webpki"),
97-
not(feature = "_tls-rustls-ring-native-roots")
98-
))]
99-
let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
100-
#[cfg(any(
101-
feature = "_tls-rustls-ring-webpki",
102-
feature = "_tls-rustls-ring-native-roots"
103-
))]
104-
let provider = Arc::new(rustls::crypto::ring::default_provider());
105-
106-
// Unwrapping is safe here because we use a default provider.
107-
let config = ClientConfig::builder_with_provider(provider.clone())
90+
impl TlsConfig<'_> {
91+
async fn rustls_config(&self) -> crate::Result<(rustls::ClientConfig, &str), Error> {
92+
let RawTlsConfig {
93+
accept_invalid_certs,
94+
accept_invalid_hostnames,
95+
hostname,
96+
root_cert,
97+
client_cert,
98+
client_key,
99+
} = match self {
100+
TlsConfig::RawTlsConfig(raw) => raw,
101+
TlsConfig::PrebuiltRustls { config, hostname } => {
102+
return Ok(((*config).to_owned(), hostname));
103+
}
104+
};
105+
106+
#[cfg(all(
107+
feature = "_tls-rustls-aws-lc-rs",
108+
not(feature = "_tls-rustls-ring-webpki"),
109+
not(feature = "_tls-rustls-ring-native-roots")
110+
))]
111+
let config = ClientConfig::builder_with_provider(Arc::new(
112+
rustls::crypto::aws_lc_rs::default_provider(),
113+
))
108114
.with_safe_default_protocol_versions()
109115
.unwrap();
116+
#[cfg(any(
117+
feature = "_tls-rustls-ring-webpki",
118+
feature = "_tls-rustls-ring-native-roots"
119+
))]
120+
let config =
121+
ClientConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
122+
.with_safe_default_protocol_versions()
123+
.unwrap();
124+
#[cfg(all(
125+
not(feature = "_tls-rustls-ring-webpki"),
126+
not(feature = "_tls-rustls-ring-native-roots")
127+
))]
128+
let config = ClientConfig::builder()
129+
.with_safe_default_protocol_versions()
130+
.unwrap();
131+
132+
// authentication using user's key and its associated certificate
133+
let user_auth = match (client_cert, client_key) {
134+
(Some(cert), Some(key)) => {
135+
let cert_chain = certs_from_pem(cert.data().await?)?;
136+
let key_der = private_key_from_pem(key.data().await?)?;
137+
Some((cert_chain, key_der))
138+
}
139+
(None, None) => None,
140+
(_, _) => {
141+
return Err(Error::Configuration(
142+
"user auth key and certs must be given together".into(),
143+
))
144+
}
145+
};
110146

111-
// authentication using user's key and its associated certificate
112-
let user_auth = match (tls_config.client_cert_path, tls_config.client_key_path) {
113-
(Some(cert_path), Some(key_path)) => {
114-
let cert_chain = certs_from_pem(cert_path.data().await?)?;
115-
let key_der = private_key_from_pem(key_path.data().await?)?;
116-
Some((cert_chain, key_der))
117-
}
118-
(None, None) => None,
119-
(_, _) => {
120-
return Err(Error::Configuration(
121-
"user auth key and certs must be given together".into(),
122-
))
123-
}
124-
};
147+
let provider = config.crypto_provider().clone();
125148

126-
let config = if tls_config.accept_invalid_certs {
127-
if let Some(user_auth) = user_auth {
128-
config
129-
.dangerous()
130-
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
131-
.with_client_auth_cert(user_auth.0, user_auth.1)
132-
.map_err(Error::tls)?
149+
let config = if *accept_invalid_certs {
150+
if let Some(user_auth) = user_auth {
151+
config
152+
.dangerous()
153+
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
154+
.with_client_auth_cert(user_auth.0, user_auth.1)
155+
.map_err(Error::tls)?
156+
} else {
157+
config
158+
.dangerous()
159+
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
160+
.with_no_client_auth()
161+
}
133162
} else {
134-
config
135-
.dangerous()
136-
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
137-
.with_no_client_auth()
138-
}
139-
} else {
140-
let mut cert_store = import_root_certs();
163+
let mut cert_store = import_root_certs();
141164

142-
if let Some(ca) = tls_config.root_cert_path {
143-
let data = ca.data().await?;
165+
if let Some(ca) = root_cert {
166+
let data = ca.data().await?;
144167

145-
for result in CertificateDer::pem_slice_iter(&data) {
146-
let Ok(cert) = result else {
147-
return Err(Error::Tls(format!("Invalid certificate {ca}").into()));
148-
};
168+
for result in CertificateDer::pem_slice_iter(&data) {
169+
let Ok(cert) = result else {
170+
return Err(Error::Tls(format!("Invalid certificate {ca}").into()));
171+
};
149172

150-
cert_store.add(cert).map_err(|err| Error::Tls(err.into()))?;
173+
cert_store.add(cert).map_err(|err| Error::Tls(err.into()))?;
174+
}
151175
}
152-
}
153-
154-
if tls_config.accept_invalid_hostnames {
155-
let verifier = WebPkiServerVerifier::builder(Arc::new(cert_store))
156-
.build()
157-
.map_err(|err| Error::Tls(err.into()))?;
158176

159-
if let Some(user_auth) = user_auth {
177+
if *accept_invalid_hostnames {
178+
let verifier = WebPkiServerVerifier::builder(Arc::new(cert_store))
179+
.build()
180+
.map_err(|err| Error::Tls(err.into()))?;
181+
182+
if let Some(user_auth) = user_auth {
183+
config
184+
.dangerous()
185+
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier {
186+
verifier,
187+
}))
188+
.with_client_auth_cert(user_auth.0, user_auth.1)
189+
.map_err(Error::tls)?
190+
} else {
191+
config
192+
.dangerous()
193+
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier {
194+
verifier,
195+
}))
196+
.with_no_client_auth()
197+
}
198+
} else if let Some(user_auth) = user_auth {
160199
config
161-
.dangerous()
162-
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
200+
.with_root_certificates(cert_store)
163201
.with_client_auth_cert(user_auth.0, user_auth.1)
164202
.map_err(Error::tls)?
165203
} else {
166204
config
167-
.dangerous()
168-
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
205+
.with_root_certificates(cert_store)
169206
.with_no_client_auth()
170207
}
171-
} else if let Some(user_auth) = user_auth {
172-
config
173-
.with_root_certificates(cert_store)
174-
.with_client_auth_cert(user_auth.0, user_auth.1)
175-
.map_err(Error::tls)?
176-
} else {
177-
config
178-
.with_root_certificates(cert_store)
179-
.with_no_client_auth()
180-
}
181-
};
208+
};
209+
210+
Ok((config, hostname))
211+
}
212+
}
182213

183-
let host = ServerName::try_from(tls_config.hostname.to_owned()).map_err(Error::tls)?;
214+
pub async fn handshake<S>(socket: S, tls_config: TlsConfig<'_>) -> Result<RustlsSocket<S>, Error>
215+
where
216+
S: Socket,
217+
{
218+
let (config, hostname) = tls_config.rustls_config().await?;
219+
let host = ServerName::try_from(hostname.to_owned()).map_err(Error::tls)?;
184220

185221
let mut socket = RustlsSocket {
186222
inner: StdSocket::new(socket),

sqlx-mysql/src/connection/tls.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use sqlx_core::net::tls::RawTlsConfig;
2+
13
use crate::connection::{MySqlStream, Waiting};
24
use crate::error::Error;
35
use crate::net::tls::TlsConfig;
@@ -53,17 +55,17 @@ pub(super) async fn maybe_upgrade<S: Socket>(
5355
}
5456
}
5557

56-
let tls_config = TlsConfig {
58+
let tls_config = TlsConfig::RawTlsConfig(RawTlsConfig {
5759
accept_invalid_certs: !matches!(
5860
options.ssl_mode,
5961
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
6062
),
6163
accept_invalid_hostnames: !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity),
6264
hostname: &options.host,
63-
root_cert_path: options.ssl_ca.as_ref(),
64-
client_cert_path: options.ssl_client_cert.as_ref(),
65-
client_key_path: options.ssl_client_key.as_ref(),
66-
};
65+
root_cert: options.ssl_ca.as_ref(),
66+
client_cert: options.ssl_client_cert.as_ref(),
67+
client_key: options.ssl_client_key.as_ref(),
68+
});
6769

6870
// Request TLS upgrade
6971
stream.write_packet(SslRequest {

sqlx-postgres/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ any = ["sqlx-core/any"]
1414
json = ["sqlx-core/json"]
1515
migrate = ["sqlx-core/migrate"]
1616
offline = ["sqlx-core/offline"]
17+
rustls = ["dep:rustls", "sqlx-core/_tls-rustls"]
1718

1819
# Type Integration features
1920
bigdecimal = ["dep:bigdecimal", "dep:num-bigint", "sqlx-core/bigdecimal"]
@@ -27,6 +28,9 @@ time = ["dep:time", "sqlx-core/time"]
2728
uuid = ["dep:uuid", "sqlx-core/uuid"]
2829

2930
[dependencies]
31+
# TLS
32+
rustls = { version = "0.23.24", default-features = false, features = ["std", "tls12"], optional = true }
33+
3034
# Futures crates
3135
futures-channel = { version = "0.3.19", default-features = false, features = ["sink", "alloc", "std"] }
3236
futures-core = { version = "0.3.19", default-features = false }

0 commit comments

Comments
 (0)