Skip to content

Commit 1be3a99

Browse files
feat: Allow users to specify a prebuilt 'rustls' configuration for TLS
1 parent 064d649 commit 1be3a99

File tree

9 files changed

+311
-155
lines changed

9 files changed

+311
-155
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlx-core/src/net/tls/mod.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,25 @@ impl std::fmt::Display for CertificateInput {
5757
}
5858
}
5959

60-
pub struct TlsConfig<'a> {
60+
#[derive(Debug, Clone)]
61+
#[non_exhaustive]
62+
pub enum TlsConfig<'a> {
63+
RawTlsConfig(RawTlsConfig<'a>),
64+
#[cfg(feature = "_tls-rustls")]
65+
PrebuiltRustls {
66+
config: &'a rustls::ClientConfig,
67+
hostname: &'a str,
68+
},
69+
}
70+
71+
#[derive(Debug, Clone)]
72+
pub struct RawTlsConfig<'a> {
6173
pub accept_invalid_certs: bool,
6274
pub accept_invalid_hostnames: bool,
6375
pub hostname: &'a str,
64-
pub root_cert_path: Option<&'a CertificateInput>,
65-
pub client_cert_path: Option<&'a CertificateInput>,
66-
pub client_key_path: Option<&'a CertificateInput>,
76+
pub root_cert: Option<&'a CertificateInput>,
77+
pub client_cert: Option<&'a CertificateInput>,
78+
pub client_key: Option<&'a CertificateInput>,
6779
}
6880

6981
pub async fn handshake<S, Ws>(

sqlx-core/src/net/tls/tls_native_tls.rs

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::io::{self, Read, Write};
22

33
use crate::io::ReadBuf;
44
use crate::net::tls::util::StdSocket;
5+
use crate::net::tls::RawTlsConfig;
56
use crate::net::tls::TlsConfig;
67
use crate::net::Socket;
78
use crate::rt;
@@ -39,36 +40,55 @@ impl<S: Socket> Socket for NativeTlsSocket<S> {
3940
}
4041
}
4142

42-
pub async fn handshake<S: Socket>(
43-
socket: S,
44-
config: TlsConfig<'_>,
45-
) -> crate::Result<NativeTlsSocket<S>> {
46-
let mut builder = native_tls::TlsConnector::builder();
47-
48-
builder
49-
.danger_accept_invalid_certs(config.accept_invalid_certs)
50-
.danger_accept_invalid_hostnames(config.accept_invalid_hostnames);
43+
impl TlsConfig<'_> {
44+
async fn native_tls_connector(&self) -> crate::Result<(native_tls::TlsConnector, &str), Error> {
45+
let TlsConfig::RawTlsConfig(RawTlsConfig {
46+
root_cert,
47+
client_cert,
48+
client_key,
49+
accept_invalid_certs,
50+
accept_invalid_hostnames,
51+
hostname,
52+
}) = self
53+
else {
54+
unreachable!()
55+
};
56+
let mut builder = native_tls::TlsConnector::builder();
57+
58+
builder
59+
.danger_accept_invalid_certs(*accept_invalid_certs)
60+
.danger_accept_invalid_hostnames(*accept_invalid_hostnames);
61+
62+
if let Some(root_cert) = root_cert {
63+
let data = root_cert.data().await?;
64+
builder.add_root_certificate(
65+
native_tls::Certificate::from_pem(&data).map_err(Error::tls)?,
66+
);
67+
}
5168

52-
if let Some(root_cert_path) = config.root_cert_path {
53-
let data = root_cert_path.data().await?;
54-
builder.add_root_certificate(native_tls::Certificate::from_pem(&data).map_err(Error::tls)?);
55-
}
69+
// authentication using user's key-file and its associated certificate
70+
if let (Some(cert), Some(key)) = (client_cert, client_key) {
71+
let cert = cert.data().await?;
72+
let key = key.data().await?;
73+
let identity = Identity::from_pkcs8(&cert, &key).map_err(Error::tls)?;
74+
builder.identity(identity);
75+
}
5676

57-
// authentication using user's key-file and its associated certificate
58-
if let (Some(cert_path), Some(key_path)) = (config.client_cert_path, config.client_key_path) {
59-
let cert_path = cert_path.data().await?;
60-
let key_path = key_path.data().await?;
61-
let identity = Identity::from_pkcs8(&cert_path, &key_path).map_err(Error::tls)?;
62-
builder.identity(identity);
77+
// The openssl TlsConnector synchronously loads certificates from files.
78+
// Loading these files can block for tens of milliseconds.
79+
let connector = rt::spawn_blocking(move || builder.build())
80+
.await
81+
.map_err(Error::tls)?;
82+
Ok((connector, hostname))
6383
}
84+
}
6485

65-
// The openssl TlsConnector synchronously loads certificates from files.
66-
// Loading these files can block for tens of milliseconds.
67-
let connector = rt::spawn_blocking(move || builder.build())
68-
.await
69-
.map_err(Error::tls)?;
70-
71-
let mut mid_handshake = match connector.connect(config.hostname, StdSocket::new(socket)) {
86+
pub async fn handshake<S: Socket>(
87+
socket: S,
88+
config: TlsConfig<'_>,
89+
) -> crate::Result<NativeTlsSocket<S>> {
90+
let (connector, hostname) = config.native_tls_connector().await?;
91+
let mut mid_handshake = match connector.connect(hostname, StdSocket::new(socket)) {
7292
Ok(tls_stream) => return Ok(NativeTlsSocket { stream: tls_stream }),
7393
Err(HandshakeError::Failure(e)) => return Err(Error::tls(e)),
7494
Err(HandshakeError::WouldBlock(mid_handshake)) => mid_handshake,

sqlx-core/src/net/tls/tls_rustls.rs

Lines changed: 105 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use rustls::{
1919
use crate::error::Error;
2020
use crate::io::ReadBuf;
2121
use crate::net::tls::util::StdSocket;
22-
use crate::net::tls::TlsConfig;
22+
use crate::net::tls::{RawTlsConfig, TlsConfig};
2323
use crate::net::Socket;
2424

2525
pub struct RustlsSocket<S: Socket> {
@@ -87,100 +87,125 @@ impl<S: Socket> Socket for RustlsSocket<S> {
8787
}
8888
}
8989

90-
pub async fn handshake<S>(socket: S, tls_config: TlsConfig<'_>) -> Result<RustlsSocket<S>, Error>
91-
where
92-
S: Socket,
93-
{
94-
#[cfg(all(
95-
feature = "_tls-rustls-aws-lc-rs",
96-
not(feature = "_tls-rustls-ring-webpki"),
97-
not(feature = "_tls-rustls-ring-native-roots")
98-
))]
99-
let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
100-
#[cfg(any(
101-
feature = "_tls-rustls-ring-webpki",
102-
feature = "_tls-rustls-ring-native-roots"
103-
))]
104-
let provider = Arc::new(rustls::crypto::ring::default_provider());
105-
106-
// Unwrapping is safe here because we use a default provider.
107-
let config = ClientConfig::builder_with_provider(provider.clone())
108-
.with_safe_default_protocol_versions()
109-
.unwrap();
110-
111-
// authentication using user's key and its associated certificate
112-
let user_auth = match (tls_config.client_cert_path, tls_config.client_key_path) {
113-
(Some(cert_path), Some(key_path)) => {
114-
let cert_chain = certs_from_pem(cert_path.data().await?)?;
115-
let key_der = private_key_from_pem(key_path.data().await?)?;
116-
Some((cert_chain, key_der))
117-
}
118-
(None, None) => None,
119-
(_, _) => {
120-
return Err(Error::Configuration(
121-
"user auth key and certs must be given together".into(),
122-
))
123-
}
124-
};
90+
impl TlsConfig<'_> {
91+
async fn rustls_config(&self) -> crate::Result<(rustls::ClientConfig, &str), Error> {
92+
let RawTlsConfig {
93+
accept_invalid_certs,
94+
accept_invalid_hostnames,
95+
hostname,
96+
root_cert,
97+
client_cert,
98+
client_key,
99+
} = match self {
100+
TlsConfig::RawTlsConfig(raw) => raw,
101+
TlsConfig::PrebuiltRustls { config, hostname } => {
102+
return Ok(((*config).to_owned(), hostname));
103+
}
104+
};
105+
106+
#[cfg(all(
107+
feature = "_tls-rustls-aws-lc-rs",
108+
not(feature = "_tls-rustls-ring-webpki"),
109+
not(feature = "_tls-rustls-ring-native-roots")
110+
))]
111+
let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
112+
#[cfg(any(
113+
feature = "_tls-rustls-ring-webpki",
114+
feature = "_tls-rustls-ring-native-roots"
115+
))]
116+
let provider = Arc::new(rustls::crypto::ring::default_provider());
117+
118+
// Unwrapping is safe here because we use a default provider.
119+
let config = ClientConfig::builder_with_provider(provider.clone())
120+
.with_safe_default_protocol_versions()
121+
.unwrap();
122+
123+
// authentication using user's key and its associated certificate
124+
let user_auth = match (client_cert, client_key) {
125+
(Some(cert), Some(key)) => {
126+
let cert_chain = certs_from_pem(cert.data().await?)?;
127+
let key_der = private_key_from_pem(key.data().await?)?;
128+
Some((cert_chain, key_der))
129+
}
130+
(None, None) => None,
131+
(_, _) => {
132+
return Err(Error::Configuration(
133+
"user auth key and certs must be given together".into(),
134+
))
135+
}
136+
};
125137

126-
let config = if tls_config.accept_invalid_certs {
127-
if let Some(user_auth) = user_auth {
128-
config
129-
.dangerous()
130-
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
131-
.with_client_auth_cert(user_auth.0, user_auth.1)
132-
.map_err(Error::tls)?
138+
let config = if *accept_invalid_certs {
139+
if let Some(user_auth) = user_auth {
140+
config
141+
.dangerous()
142+
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
143+
.with_client_auth_cert(user_auth.0, user_auth.1)
144+
.map_err(Error::tls)?
145+
} else {
146+
config
147+
.dangerous()
148+
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
149+
.with_no_client_auth()
150+
}
133151
} else {
134-
config
135-
.dangerous()
136-
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
137-
.with_no_client_auth()
138-
}
139-
} else {
140-
let mut cert_store = import_root_certs();
152+
let mut cert_store = import_root_certs();
141153

142-
if let Some(ca) = tls_config.root_cert_path {
143-
let data = ca.data().await?;
154+
if let Some(ca) = root_cert {
155+
let data = ca.data().await?;
144156

145-
for result in CertificateDer::pem_slice_iter(&data) {
146-
let Ok(cert) = result else {
147-
return Err(Error::Tls(format!("Invalid certificate {ca}").into()));
148-
};
157+
for result in CertificateDer::pem_slice_iter(&data) {
158+
let Ok(cert) = result else {
159+
return Err(Error::Tls(format!("Invalid certificate {ca}").into()));
160+
};
149161

150-
cert_store.add(cert).map_err(|err| Error::Tls(err.into()))?;
162+
cert_store.add(cert).map_err(|err| Error::Tls(err.into()))?;
163+
}
151164
}
152-
}
153-
154-
if tls_config.accept_invalid_hostnames {
155-
let verifier = WebPkiServerVerifier::builder(Arc::new(cert_store))
156-
.build()
157-
.map_err(|err| Error::Tls(err.into()))?;
158165

159-
if let Some(user_auth) = user_auth {
166+
if *accept_invalid_hostnames {
167+
let verifier = WebPkiServerVerifier::builder(Arc::new(cert_store))
168+
.build()
169+
.map_err(|err| Error::Tls(err.into()))?;
170+
171+
if let Some(user_auth) = user_auth {
172+
config
173+
.dangerous()
174+
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier {
175+
verifier,
176+
}))
177+
.with_client_auth_cert(user_auth.0, user_auth.1)
178+
.map_err(Error::tls)?
179+
} else {
180+
config
181+
.dangerous()
182+
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier {
183+
verifier,
184+
}))
185+
.with_no_client_auth()
186+
}
187+
} else if let Some(user_auth) = user_auth {
160188
config
161-
.dangerous()
162-
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
189+
.with_root_certificates(cert_store)
163190
.with_client_auth_cert(user_auth.0, user_auth.1)
164191
.map_err(Error::tls)?
165192
} else {
166193
config
167-
.dangerous()
168-
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
194+
.with_root_certificates(cert_store)
169195
.with_no_client_auth()
170196
}
171-
} else if let Some(user_auth) = user_auth {
172-
config
173-
.with_root_certificates(cert_store)
174-
.with_client_auth_cert(user_auth.0, user_auth.1)
175-
.map_err(Error::tls)?
176-
} else {
177-
config
178-
.with_root_certificates(cert_store)
179-
.with_no_client_auth()
180-
}
181-
};
197+
};
198+
199+
Ok((config, hostname))
200+
}
201+
}
182202

183-
let host = ServerName::try_from(tls_config.hostname.to_owned()).map_err(Error::tls)?;
203+
pub async fn handshake<S>(socket: S, tls_config: TlsConfig<'_>) -> Result<RustlsSocket<S>, Error>
204+
where
205+
S: Socket,
206+
{
207+
let (config, hostname) = tls_config.rustls_config().await?;
208+
let host = ServerName::try_from(hostname.to_owned()).map_err(Error::tls)?;
184209

185210
let mut socket = RustlsSocket {
186211
inner: StdSocket::new(socket),

sqlx-mysql/src/connection/tls.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use sqlx_core::net::tls::RawTlsConfig;
2+
13
use crate::connection::{MySqlStream, Waiting};
24
use crate::error::Error;
35
use crate::net::tls::TlsConfig;
@@ -53,17 +55,17 @@ pub(super) async fn maybe_upgrade<S: Socket>(
5355
}
5456
}
5557

56-
let tls_config = TlsConfig {
58+
let tls_config = TlsConfig::RawTlsConfig(RawTlsConfig {
5759
accept_invalid_certs: !matches!(
5860
options.ssl_mode,
5961
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
6062
),
6163
accept_invalid_hostnames: !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity),
6264
hostname: &options.host,
63-
root_cert_path: options.ssl_ca.as_ref(),
64-
client_cert_path: options.ssl_client_cert.as_ref(),
65-
client_key_path: options.ssl_client_key.as_ref(),
66-
};
65+
root_cert: options.ssl_ca.as_ref(),
66+
client_cert: options.ssl_client_cert.as_ref(),
67+
client_key: options.ssl_client_key.as_ref(),
68+
});
6769

6870
// Request TLS upgrade
6971
stream.write_packet(SslRequest {

sqlx-postgres/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ any = ["sqlx-core/any"]
1414
json = ["sqlx-core/json"]
1515
migrate = ["sqlx-core/migrate"]
1616
offline = ["sqlx-core/offline"]
17+
rustls = ["dep:rustls"]
1718

1819
# Type Integration features
1920
bigdecimal = ["dep:bigdecimal", "dep:num-bigint", "sqlx-core/bigdecimal"]
@@ -27,6 +28,9 @@ time = ["dep:time", "sqlx-core/time"]
2728
uuid = ["dep:uuid", "sqlx-core/uuid"]
2829

2930
[dependencies]
31+
# TLS
32+
rustls = { version = "0.23.24", default-features = false, features = ["std", "tls12"], optional = true }
33+
3034
# Futures crates
3135
futures-channel = { version = "0.3.19", default-features = false, features = ["sink", "alloc", "std"] }
3236
futures-core = { version = "0.3.19", default-features = false }

0 commit comments

Comments
 (0)