Skip to content

Commit 59069b1

Browse files
committed
Moved tls config generation and type conversions into the config parsing, and made settings names more descriptive.
1 parent 70b6841 commit 59069b1

File tree

8 files changed

+321
-115
lines changed

8 files changed

+321
-115
lines changed

src/config.rs

Lines changed: 165 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,23 @@ use std::{
33
net::SocketAddr,
44
os::unix::fs::PermissionsExt,
55
path::{Path, PathBuf},
6+
sync::Arc,
7+
time::Duration,
68
};
79

8-
use rustls::pki_types::ServerName;
10+
use rustls::{
11+
pki_types::{pem::PemObject, ServerName},
12+
version::TLS13,
13+
};
14+
use rustls_platform_verifier::Verifier;
915
use serde::Deserialize;
16+
use tokio_rustls::{TlsAcceptor, TlsConnector};
1017
use tracing::{info, warn};
1118

12-
#[derive(Deserialize, Debug)]
19+
#[derive(Deserialize)]
1320
#[serde(rename_all = "kebab-case", deny_unknown_fields)]
1421
pub struct Config {
15-
pub nts_pool_ke_server: NtsPoolKeConfig,
22+
pub server: NtsPoolKeConfig,
1623
#[serde(default)]
1724
pub observability: ObservabilityConfig,
1825
}
@@ -83,20 +90,112 @@ pub struct ObservabilityConfig {
8390

8491
#[derive(Debug, PartialEq, Eq, Clone, Deserialize)]
8592
#[serde(rename_all = "kebab-case", deny_unknown_fields)]
86-
pub struct NtsPoolKeConfig {
87-
pub certificate_authority_path: PathBuf,
88-
pub certificate_chain_path: PathBuf,
89-
pub private_key_path: PathBuf,
93+
struct BareNtsPoolKeConfig {
94+
/// Additional CAs used to validate the certificates of upstream servers
95+
#[serde(default)]
96+
upstream_cas: Option<PathBuf>,
97+
/// Certificate chain for the key used by the server to identify itself during tls sessions
98+
certificate_chain: PathBuf,
99+
/// Private key used by the server to identify itself during tls sessions
100+
private_key: PathBuf,
90101
#[serde(default = "default_nts_ke_timeout")]
91-
pub key_exchange_timeout_ms: u64,
92-
pub listen: SocketAddr,
93-
pub key_exchange_servers: Vec<KeyExchangeServer>,
102+
/// Timeout
103+
key_exchange_timeout: u64,
104+
/// Address for the server to listen on.
105+
listen: SocketAddr,
106+
/// Which upstream servers to use.
107+
key_exchange_servers: Box<[KeyExchangeServer]>,
94108
}
95109

96110
fn default_nts_ke_timeout() -> u64 {
97111
1000
98112
}
99113

114+
#[derive(Clone)]
115+
pub struct NtsPoolKeConfig {
116+
pub server_tls: TlsAcceptor,
117+
pub upstream_tls: TlsConnector,
118+
pub listen: SocketAddr,
119+
pub key_exchange_servers: Box<[KeyExchangeServer]>,
120+
pub key_exchange_timeout: Duration,
121+
}
122+
123+
fn load_certificates(
124+
path: impl AsRef<std::path::Path>,
125+
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, rustls::pki_types::pem::Error> {
126+
rustls::pki_types::CertificateDer::pem_file_iter(path)?.collect()
127+
}
128+
129+
impl<'de> Deserialize<'de> for NtsPoolKeConfig {
130+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
131+
where
132+
D: serde::Deserializer<'de>,
133+
{
134+
let bare = BareNtsPoolKeConfig::deserialize(deserializer)?;
135+
136+
let upstream_cas = bare
137+
.upstream_cas
138+
.map(|path| {
139+
load_certificates(&path).map_err(|e| {
140+
serde::de::Error::custom(format!(
141+
"error reading additional upstream ca certificates from `{:?}`: {:?}",
142+
path, e
143+
))
144+
})
145+
})
146+
.transpose()?;
147+
148+
let certificate_chain = load_certificates(&bare.certificate_chain).map_err(|e| {
149+
serde::de::Error::custom(format!(
150+
"error reading server's certificate chain from `{:?}`: {:?}",
151+
bare.certificate_chain, e
152+
))
153+
})?;
154+
155+
let private_key = rustls::pki_types::PrivateKeyDer::from_pem_file(&bare.private_key)
156+
.map_err(|e| {
157+
serde::de::Error::custom(format!(
158+
"error reading server's private key from `{:?}`: {:?}",
159+
bare.private_key, e
160+
))
161+
})?;
162+
163+
let mut server_config = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13])
164+
.with_no_client_auth()
165+
.with_single_cert(certificate_chain.clone(), private_key.clone_key())
166+
.map_err(serde::de::Error::custom)?;
167+
server_config.alpn_protocols.clear();
168+
server_config.alpn_protocols.push(b"ntske/1".to_vec());
169+
170+
let server_tls = TlsAcceptor::from(Arc::new(server_config));
171+
172+
let upstream_config_builder =
173+
rustls::ClientConfig::builder_with_protocol_versions(&[&TLS13]);
174+
let provider = upstream_config_builder.crypto_provider().clone();
175+
let verifier = match upstream_cas {
176+
Some(upstream_cas) => Verifier::new_with_extra_roots(upstream_cas.iter().cloned())
177+
.map_err(serde::de::Error::custom)?
178+
.with_provider(provider),
179+
None => Verifier::new(),
180+
};
181+
182+
let upstream_config = upstream_config_builder
183+
.dangerous()
184+
.with_custom_certificate_verifier(Arc::new(verifier))
185+
.with_client_auth_cert(certificate_chain, private_key)
186+
.map_err(serde::de::Error::custom)?;
187+
let upstream_tls = TlsConnector::from(Arc::new(upstream_config));
188+
189+
Ok(NtsPoolKeConfig {
190+
server_tls,
191+
upstream_tls,
192+
listen: bare.listen,
193+
key_exchange_servers: bare.key_exchange_servers,
194+
key_exchange_timeout: std::time::Duration::from_millis(bare.key_exchange_timeout),
195+
})
196+
}
197+
}
198+
100199
#[derive(Debug, PartialEq, Eq, Clone)]
101200
pub struct KeyExchangeServer {
102201
pub domain: String,
@@ -135,17 +234,18 @@ impl<'de> Deserialize<'de> for KeyExchangeServer {
135234

136235
#[cfg(test)]
137236
mod tests {
237+
use std::ops::Deref;
238+
138239
use super::*;
139240

140241
#[test]
141-
fn test_deserialize_nts_pool_ke() {
142-
let test: Config = toml::from_str(
242+
fn test_deserialize_bare_config() {
243+
let test: BareNtsPoolKeConfig = toml::from_str(
143244
r#"
144-
[nts-pool-ke-server]
145245
listen = "0.0.0.0:4460"
146-
certificate-authority-path = "/foo/bar/ca.pem"
147-
certificate-chain-path = "/foo/bar/baz.pem"
148-
private-key-path = "spam.der"
246+
upstream-cas = "/foo/bar/ca.pem"
247+
certificate-chain = "/foo/bar/baz.pem"
248+
private-key = "spam.der"
149249
key-exchange-servers = [
150250
{ domain = "foo.bar", port = 1234 },
151251
{ domain = "bar.foo", port = 4321 },
@@ -155,23 +255,65 @@ mod tests {
155255
.unwrap();
156256

157257
let ca = PathBuf::from("/foo/bar/ca.pem");
158-
assert_eq!(test.nts_pool_ke_server.certificate_authority_path, ca);
258+
assert_eq!(test.upstream_cas, Some(ca));
159259

160260
let chain = PathBuf::from("/foo/bar/baz.pem");
161-
assert_eq!(test.nts_pool_ke_server.certificate_chain_path, chain);
261+
assert_eq!(test.certificate_chain, chain);
162262

163263
let private_key = PathBuf::from("spam.der");
164-
assert_eq!(test.nts_pool_ke_server.private_key_path, private_key);
264+
assert_eq!(test.private_key, private_key);
265+
266+
assert_eq!(test.key_exchange_timeout, 1000,);
267+
assert_eq!(test.listen, "0.0.0.0:4460".parse().unwrap(),);
165268

166-
assert_eq!(test.nts_pool_ke_server.key_exchange_timeout_ms, 1000,);
269+
assert_eq!(
270+
test.key_exchange_servers.deref(),
271+
[
272+
KeyExchangeServer {
273+
domain: String::from("foo.bar"),
274+
server_name: ServerName::try_from("foo.bar").unwrap(),
275+
port: 1234
276+
},
277+
KeyExchangeServer {
278+
domain: String::from("bar.foo"),
279+
server_name: ServerName::try_from("bar.foo").unwrap(),
280+
port: 4321
281+
},
282+
]
283+
.as_slice()
284+
);
285+
}
286+
287+
#[test]
288+
fn test_deserialize_config() {
289+
let test: Config = toml::from_str(
290+
r#"
291+
[server]
292+
listen = "0.0.0.0:4460"
293+
key-exchange-timeout = 500
294+
upstream-cas = "testdata/testca.pem"
295+
certificate-chain = "testdata/end.fullchain.pem"
296+
private-key = "testdata/end.key"
297+
key-exchange-servers = [
298+
{ domain = "foo.bar", port = 1234 },
299+
{ domain = "bar.foo", port = 4321 },
300+
]
301+
"#,
302+
)
303+
.unwrap();
304+
305+
assert_eq!(
306+
test.nts_pool_ke_server.key_exchange_timeout,
307+
Duration::from_millis(500)
308+
);
167309
assert_eq!(
168310
test.nts_pool_ke_server.listen,
169311
"0.0.0.0:4460".parse().unwrap(),
170312
);
171313

172314
assert_eq!(
173-
test.nts_pool_ke_server.key_exchange_servers,
174-
vec![
315+
test.nts_pool_ke_server.key_exchange_servers.deref(),
316+
[
175317
KeyExchangeServer {
176318
domain: String::from("foo.bar"),
177319
server_name: ServerName::try_from("foo.bar").unwrap(),
@@ -183,6 +325,7 @@ mod tests {
183325
port: 4321
184326
},
185327
]
328+
.as_slice()
186329
);
187330
}
188331
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async fn run(options: NtsPoolKeOptions) -> Result<(), Box<dyn std::error::Error>
9999
// tracing setup to ensure logging is fully configured.
100100
config.check();
101101

102-
let result = run_nts_pool_ke(config.nts_pool_ke_server).await;
102+
let result = run_nts_pool_ke(config.server).await;
103103

104104
match result {
105105
Ok(v) => Ok(v),

0 commit comments

Comments
 (0)