@@ -3,16 +3,23 @@ use std::{
33 net:: SocketAddr ,
44 os:: unix:: fs:: PermissionsExt ,
55 path:: { Path , PathBuf } ,
6+ sync:: Arc ,
7+ time:: Duration ,
68} ;
79
8- use rustls:: pki_types:: ServerName ;
10+ use rustls:: {
11+ pki_types:: { pem:: PemObject , ServerName } ,
12+ version:: TLS13 ,
13+ } ;
14+ use rustls_platform_verifier:: Verifier ;
915use serde:: Deserialize ;
16+ use tokio_rustls:: { TlsAcceptor , TlsConnector } ;
1017use tracing:: { info, warn} ;
1118
12- #[ derive( Deserialize , Debug ) ]
19+ #[ derive( Deserialize ) ]
1320#[ serde( rename_all = "kebab-case" , deny_unknown_fields) ]
1421pub struct Config {
15- pub nts_pool_ke_server : NtsPoolKeConfig ,
22+ pub server : NtsPoolKeConfig ,
1623 #[ serde( default ) ]
1724 pub observability : ObservabilityConfig ,
1825}
@@ -83,20 +90,112 @@ pub struct ObservabilityConfig {
8390
8491#[ derive( Debug , PartialEq , Eq , Clone , Deserialize ) ]
8592#[ serde( rename_all = "kebab-case" , deny_unknown_fields) ]
86- pub struct NtsPoolKeConfig {
87- pub certificate_authority_path : PathBuf ,
88- pub certificate_chain_path : PathBuf ,
89- pub private_key_path : PathBuf ,
93+ struct BareNtsPoolKeConfig {
94+ /// Additional CAs used to validate the certificates of upstream servers
95+ #[ serde( default ) ]
96+ upstream_cas : Option < PathBuf > ,
97+ /// Certificate chain for the key used by the server to identify itself during tls sessions
98+ certificate_chain : PathBuf ,
99+ /// Private key used by the server to identify itself during tls sessions
100+ private_key : PathBuf ,
90101 #[ serde( default = "default_nts_ke_timeout" ) ]
91- pub key_exchange_timeout_ms : u64 ,
92- pub listen : SocketAddr ,
93- pub key_exchange_servers : Vec < KeyExchangeServer > ,
102+ /// Timeout
103+ key_exchange_timeout : u64 ,
104+ /// Address for the server to listen on.
105+ listen : SocketAddr ,
106+ /// Which upstream servers to use.
107+ key_exchange_servers : Box < [ KeyExchangeServer ] > ,
94108}
95109
96110fn default_nts_ke_timeout ( ) -> u64 {
97111 1000
98112}
99113
114+ #[ derive( Clone ) ]
115+ pub struct NtsPoolKeConfig {
116+ pub server_tls : TlsAcceptor ,
117+ pub upstream_tls : TlsConnector ,
118+ pub listen : SocketAddr ,
119+ pub key_exchange_servers : Box < [ KeyExchangeServer ] > ,
120+ pub key_exchange_timeout : Duration ,
121+ }
122+
123+ fn load_certificates (
124+ path : impl AsRef < std:: path:: Path > ,
125+ ) -> Result < Vec < rustls:: pki_types:: CertificateDer < ' static > > , rustls:: pki_types:: pem:: Error > {
126+ rustls:: pki_types:: CertificateDer :: pem_file_iter ( path) ?. collect ( )
127+ }
128+
129+ impl < ' de > Deserialize < ' de > for NtsPoolKeConfig {
130+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
131+ where
132+ D : serde:: Deserializer < ' de > ,
133+ {
134+ let bare = BareNtsPoolKeConfig :: deserialize ( deserializer) ?;
135+
136+ let upstream_cas = bare
137+ . upstream_cas
138+ . map ( |path| {
139+ load_certificates ( & path) . map_err ( |e| {
140+ serde:: de:: Error :: custom ( format ! (
141+ "error reading additional upstream ca certificates from `{:?}`: {:?}" ,
142+ path, e
143+ ) )
144+ } )
145+ } )
146+ . transpose ( ) ?;
147+
148+ let certificate_chain = load_certificates ( & bare. certificate_chain ) . map_err ( |e| {
149+ serde:: de:: Error :: custom ( format ! (
150+ "error reading server's certificate chain from `{:?}`: {:?}" ,
151+ bare. certificate_chain, e
152+ ) )
153+ } ) ?;
154+
155+ let private_key = rustls:: pki_types:: PrivateKeyDer :: from_pem_file ( & bare. private_key )
156+ . map_err ( |e| {
157+ serde:: de:: Error :: custom ( format ! (
158+ "error reading server's private key from `{:?}`: {:?}" ,
159+ bare. private_key, e
160+ ) )
161+ } ) ?;
162+
163+ let mut server_config = rustls:: ServerConfig :: builder_with_protocol_versions ( & [ & TLS13 ] )
164+ . with_no_client_auth ( )
165+ . with_single_cert ( certificate_chain. clone ( ) , private_key. clone_key ( ) )
166+ . map_err ( serde:: de:: Error :: custom) ?;
167+ server_config. alpn_protocols . clear ( ) ;
168+ server_config. alpn_protocols . push ( b"ntske/1" . to_vec ( ) ) ;
169+
170+ let server_tls = TlsAcceptor :: from ( Arc :: new ( server_config) ) ;
171+
172+ let upstream_config_builder =
173+ rustls:: ClientConfig :: builder_with_protocol_versions ( & [ & TLS13 ] ) ;
174+ let provider = upstream_config_builder. crypto_provider ( ) . clone ( ) ;
175+ let verifier = match upstream_cas {
176+ Some ( upstream_cas) => Verifier :: new_with_extra_roots ( upstream_cas. iter ( ) . cloned ( ) )
177+ . map_err ( serde:: de:: Error :: custom) ?
178+ . with_provider ( provider) ,
179+ None => Verifier :: new ( ) ,
180+ } ;
181+
182+ let upstream_config = upstream_config_builder
183+ . dangerous ( )
184+ . with_custom_certificate_verifier ( Arc :: new ( verifier) )
185+ . with_client_auth_cert ( certificate_chain, private_key)
186+ . map_err ( serde:: de:: Error :: custom) ?;
187+ let upstream_tls = TlsConnector :: from ( Arc :: new ( upstream_config) ) ;
188+
189+ Ok ( NtsPoolKeConfig {
190+ server_tls,
191+ upstream_tls,
192+ listen : bare. listen ,
193+ key_exchange_servers : bare. key_exchange_servers ,
194+ key_exchange_timeout : std:: time:: Duration :: from_millis ( bare. key_exchange_timeout ) ,
195+ } )
196+ }
197+ }
198+
100199#[ derive( Debug , PartialEq , Eq , Clone ) ]
101200pub struct KeyExchangeServer {
102201 pub domain : String ,
@@ -135,17 +234,18 @@ impl<'de> Deserialize<'de> for KeyExchangeServer {
135234
136235#[ cfg( test) ]
137236mod tests {
237+ use std:: ops:: Deref ;
238+
138239 use super :: * ;
139240
140241 #[ test]
141- fn test_deserialize_nts_pool_ke ( ) {
142- let test: Config = toml:: from_str (
242+ fn test_deserialize_bare_config ( ) {
243+ let test: BareNtsPoolKeConfig = toml:: from_str (
143244 r#"
144- [nts-pool-ke-server]
145245 listen = "0.0.0.0:4460"
146- certificate-authority-path = "/foo/bar/ca.pem"
147- certificate-chain-path = "/foo/bar/baz.pem"
148- private-key-path = "spam.der"
246+ upstream-cas = "/foo/bar/ca.pem"
247+ certificate-chain = "/foo/bar/baz.pem"
248+ private-key = "spam.der"
149249 key-exchange-servers = [
150250 { domain = "foo.bar", port = 1234 },
151251 { domain = "bar.foo", port = 4321 },
@@ -155,23 +255,65 @@ mod tests {
155255 . unwrap ( ) ;
156256
157257 let ca = PathBuf :: from ( "/foo/bar/ca.pem" ) ;
158- assert_eq ! ( test. nts_pool_ke_server . certificate_authority_path , ca ) ;
258+ assert_eq ! ( test. upstream_cas , Some ( ca ) ) ;
159259
160260 let chain = PathBuf :: from ( "/foo/bar/baz.pem" ) ;
161- assert_eq ! ( test. nts_pool_ke_server . certificate_chain_path , chain) ;
261+ assert_eq ! ( test. certificate_chain , chain) ;
162262
163263 let private_key = PathBuf :: from ( "spam.der" ) ;
164- assert_eq ! ( test. nts_pool_ke_server. private_key_path, private_key) ;
264+ assert_eq ! ( test. private_key, private_key) ;
265+
266+ assert_eq ! ( test. key_exchange_timeout, 1000 , ) ;
267+ assert_eq ! ( test. listen, "0.0.0.0:4460" . parse( ) . unwrap( ) , ) ;
165268
166- assert_eq ! ( test. nts_pool_ke_server. key_exchange_timeout_ms, 1000 , ) ;
269+ assert_eq ! (
270+ test. key_exchange_servers. deref( ) ,
271+ [
272+ KeyExchangeServer {
273+ domain: String :: from( "foo.bar" ) ,
274+ server_name: ServerName :: try_from( "foo.bar" ) . unwrap( ) ,
275+ port: 1234
276+ } ,
277+ KeyExchangeServer {
278+ domain: String :: from( "bar.foo" ) ,
279+ server_name: ServerName :: try_from( "bar.foo" ) . unwrap( ) ,
280+ port: 4321
281+ } ,
282+ ]
283+ . as_slice( )
284+ ) ;
285+ }
286+
287+ #[ test]
288+ fn test_deserialize_config ( ) {
289+ let test: Config = toml:: from_str (
290+ r#"
291+ [server]
292+ listen = "0.0.0.0:4460"
293+ key-exchange-timeout = 500
294+ upstream-cas = "testdata/testca.pem"
295+ certificate-chain = "testdata/end.fullchain.pem"
296+ private-key = "testdata/end.key"
297+ key-exchange-servers = [
298+ { domain = "foo.bar", port = 1234 },
299+ { domain = "bar.foo", port = 4321 },
300+ ]
301+ "# ,
302+ )
303+ . unwrap ( ) ;
304+
305+ assert_eq ! (
306+ test. nts_pool_ke_server. key_exchange_timeout,
307+ Duration :: from_millis( 500 )
308+ ) ;
167309 assert_eq ! (
168310 test. nts_pool_ke_server. listen,
169311 "0.0.0.0:4460" . parse( ) . unwrap( ) ,
170312 ) ;
171313
172314 assert_eq ! (
173- test. nts_pool_ke_server. key_exchange_servers,
174- vec! [
315+ test. nts_pool_ke_server. key_exchange_servers. deref ( ) ,
316+ [
175317 KeyExchangeServer {
176318 domain: String :: from( "foo.bar" ) ,
177319 server_name: ServerName :: try_from( "foo.bar" ) . unwrap( ) ,
@@ -183,6 +325,7 @@ mod tests {
183325 port: 4321
184326 } ,
185327 ]
328+ . as_slice( )
186329 ) ;
187330 }
188331}
0 commit comments