Skip to content

Commit 68594fd

Browse files
committed
Add support for custom TLS acceptors
Allow the configuration to provide a custom object to accept a TlsStream from a TcpStream. This gives total control over the TLS negotiation process, such as to process some TLS streams or ALPN negotiations internally without passing them through to tide. In particular, this allows responding to ACME tls-alpn-01 challenges.
1 parent 635ecf0 commit 68594fd

File tree

5 files changed

+83
-16
lines changed

5 files changed

+83
-16
lines changed

src/custom_tls_acceptor.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use async_rustls::server::TlsStream;
2+
use async_std::net::TcpStream;
3+
4+
/// The CustomTlsAcceptor trait provides a custom implementation of accepting
5+
/// TLS connections from a [`TcpStream`]. tide-rustls will call the
6+
/// [`CustomTlsAcceptor::accept`] function for each new [`TcpStream`] it
7+
/// accepts, to obtain a [`TlsStream`]).
8+
///
9+
/// Implementing this trait gives you control over the TLS negotiation process,
10+
/// and allows you to process some TLS connections internally without passing
11+
/// them through to tide, such as for multiplexing or custom ALPN negotiation.
12+
#[tide::utils::async_trait]
13+
pub trait CustomTlsAcceptor: Send + Sync {
14+
/// Accept a [`TlsStream`] from a [`TcpStream`].
15+
///
16+
/// If TLS negotiation succeeds, but does not result in a stream that tide
17+
/// should process HTTP connections from, return `Ok(None)`.
18+
async fn accept(&self, stream: TcpStream) -> std::io::Result<Option<TlsStream<TcpStream>>>;
19+
}
20+
21+
/// Crate-private adapter to make `async_rustls::TlsAcceptor` implement
22+
/// `CustomTlsAcceptor`, without creating a conflict between the two `accept`
23+
/// methods.
24+
pub(crate) struct StandardTlsAcceptor(pub(crate) async_rustls::TlsAcceptor);
25+
26+
#[tide::utils::async_trait]
27+
impl CustomTlsAcceptor for StandardTlsAcceptor {
28+
async fn accept(&self, stream: TcpStream) -> std::io::Result<Option<TlsStream<TcpStream>>> {
29+
self.0.accept(stream).await.map(Some)
30+
}
31+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
unused_qualifications
2929
)]
3030

31+
mod custom_tls_acceptor;
3132
mod tcp_connection;
3233
mod tls_listener;
3334
mod tls_listener_builder;
@@ -38,6 +39,7 @@ pub(crate) use tcp_connection::TcpConnection;
3839
pub(crate) use tls_listener_config::TlsListenerConfig;
3940
pub(crate) use tls_stream_wrapper::TlsStreamWrapper;
4041

42+
pub use custom_tls_acceptor::CustomTlsAcceptor;
4143
pub use tls_listener::TlsListener;
4244
pub use tls_listener_builder::TlsListenerBuilder;
4345

src/tls_listener.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use crate::{TcpConnection, TlsListenerBuilder, TlsListenerConfig, TlsStreamWrapper};
1+
use crate::custom_tls_acceptor::StandardTlsAcceptor;
2+
use crate::{
3+
CustomTlsAcceptor, TcpConnection, TlsListenerBuilder, TlsListenerConfig, TlsStreamWrapper,
4+
};
25

36
use tide::listener::ListenInfo;
47
use tide::listener::{Listener, ToListener};
@@ -79,12 +82,14 @@ impl<State> TlsListener<State> {
7982
.set_single_cert(certs, keys.remove(0))
8083
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
8184

82-
TlsListenerConfig::Acceptor(TlsAcceptor::from(Arc::new(config)))
85+
TlsListenerConfig::Acceptor(Arc::new(StandardTlsAcceptor(TlsAcceptor::from(
86+
Arc::new(config),
87+
))))
8388
}
8489

85-
TlsListenerConfig::ServerConfig(config) => {
86-
TlsListenerConfig::Acceptor(TlsAcceptor::from(Arc::new(config)))
87-
}
90+
TlsListenerConfig::ServerConfig(config) => TlsListenerConfig::Acceptor(Arc::new(
91+
StandardTlsAcceptor(TlsAcceptor::from(Arc::new(config))),
92+
)),
8893

8994
other @ TlsListenerConfig::Acceptor(_) => other,
9095

@@ -99,7 +104,7 @@ impl<State> TlsListener<State> {
99104
Ok(())
100105
}
101106

102-
fn acceptor(&self) -> Option<&TlsAcceptor> {
107+
fn acceptor(&self) -> Option<&Arc<dyn CustomTlsAcceptor>> {
103108
match self.config {
104109
TlsListenerConfig::Acceptor(ref a) => Some(a),
105110
_ => None,
@@ -125,14 +130,16 @@ impl<State> TlsListener<State> {
125130
fn handle_tls<State: Clone + Send + Sync + 'static>(
126131
app: Server<State>,
127132
stream: TcpStream,
128-
acceptor: TlsAcceptor,
133+
acceptor: Arc<dyn CustomTlsAcceptor>,
129134
) {
130135
task::spawn(async move {
131136
let local_addr = stream.local_addr().ok();
132137
let peer_addr = stream.peer_addr().ok();
133138

134139
match acceptor.accept(stream).await {
135-
Ok(tls_stream) => {
140+
Ok(None) => {}
141+
142+
Ok(Some(tls_stream)) => {
136143
let stream = TlsStreamWrapper::new(tls_stream);
137144
let fut = async_h1::accept(stream, |mut req| async {
138145
if req.url_mut().set_scheme("https").is_err() {

src/tls_listener_builder.rs

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ use async_std::net::TcpListener;
33

44
use rustls::ServerConfig;
55

6-
use super::{TcpConnection, TlsListener, TlsListenerConfig};
6+
use super::{CustomTlsAcceptor, TcpConnection, TlsListener, TlsListenerConfig};
77

88
use std::marker::PhantomData;
99
use std::net::{SocketAddr, ToSocketAddrs};
1010
use std::path::{Path, PathBuf};
11+
use std::sync::Arc;
1112

1213
/// # A builder for TlsListeners
1314
///
@@ -38,6 +39,7 @@ pub struct TlsListenerBuilder<State> {
3839
key: Option<PathBuf>,
3940
cert: Option<PathBuf>,
4041
config: Option<ServerConfig>,
42+
tls_acceptor: Option<Arc<dyn CustomTlsAcceptor>>,
4143
tcp: Option<TcpListener>,
4244
addrs: Option<Vec<SocketAddr>>,
4345
_state: PhantomData<State>,
@@ -49,6 +51,7 @@ impl<State> Default for TlsListenerBuilder<State> {
4951
key: None,
5052
cert: None,
5153
config: None,
54+
tls_acceptor: None,
5255
tcp: None,
5356
addrs: None,
5457
_state: PhantomData,
@@ -69,6 +72,14 @@ impl<State> std::fmt::Debug for TlsListenerBuilder<State> {
6972
"None"
7073
},
7174
)
75+
.field(
76+
"tls_acceptor",
77+
&if self.tls_acceptor.is_some() {
78+
"Some(_)"
79+
} else {
80+
"None"
81+
},
82+
)
7283
.field("tcp", &self.tcp)
7384
.field("addrs", &self.addrs)
7485
.finish()
@@ -108,6 +119,17 @@ impl<State> TlsListenerBuilder<State> {
108119
self
109120
}
110121

122+
/// Provides a custom acceptor for TLS connections. This is mutually
123+
/// exclusive with any of [`TlsListenerBuilder::key`],
124+
/// [`TlsListenerBuilder::cert`], and [`TlsListenerBuilder::config`], but
125+
/// gives total control over accepting TLS connections, including
126+
/// multiplexing other streams or ALPN negotiations on the same TLS
127+
/// connection that tide should ignore.
128+
pub fn tls_acceptor(mut self, acceptor: Arc<dyn CustomTlsAcceptor>) -> Self {
129+
self.tls_acceptor = Some(acceptor);
130+
self
131+
}
132+
111133
/// Provides a bound tcp listener (either async-std or std) to
112134
/// build this tls listener on. This is mutually exclusive with
113135
/// [`TlsListenerBuilder::addrs`], but one of them is mandatory.
@@ -134,26 +156,29 @@ impl<State> TlsListenerBuilder<State> {
134156
/// * either of these is provided, but not both
135157
/// * [`TlsListenerBuilder::tcp`]
136158
/// * [`TlsListenerBuilder::addrs`]
137-
/// * either of these is provided, but not both
159+
/// * exactly one of these is provided
138160
/// * both [`TlsListenerBuilder::cert`] AND [`TlsListenerBuilder::key`]
139161
/// * [`TlsListenerBuilder::config`]
162+
/// * [`TlsListenerBuilder::tls_acceptor`]
140163
pub fn finish(self) -> io::Result<TlsListener<State>> {
141164
let Self {
142165
key,
143166
cert,
144167
config,
168+
tls_acceptor,
145169
tcp,
146170
addrs,
147171
..
148172
} = self;
149173

150-
let config = match (key, cert, config) {
151-
(Some(key), Some(cert), None) => TlsListenerConfig::Paths { key, cert },
152-
(None, None, Some(config)) => TlsListenerConfig::ServerConfig(config),
174+
let config = match (key, cert, config, tls_acceptor) {
175+
(Some(key), Some(cert), None, None) => TlsListenerConfig::Paths { key, cert },
176+
(None, None, Some(config), None) => TlsListenerConfig::ServerConfig(config),
177+
(None, None, None, Some(tls_acceptor)) => TlsListenerConfig::Acceptor(tls_acceptor),
153178
_ => {
154179
return Err(io::Error::new(
155180
io::ErrorKind::InvalidInput,
156-
"either cert + key are required or a ServerConfig",
181+
"need exactly one of cert + key, ServerConfig, or TLS acceptor",
157182
))
158183
}
159184
};

src/tls_listener_config.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use std::fmt::{self, Debug, Formatter};
22

3-
use async_rustls::TlsAcceptor;
43
use rustls::ServerConfig;
54

5+
use super::CustomTlsAcceptor;
6+
67
use std::path::PathBuf;
8+
use std::sync::Arc;
79

810
impl Default for TlsListenerConfig {
911
fn default() -> Self {
@@ -12,7 +14,7 @@ impl Default for TlsListenerConfig {
1214
}
1315
pub(crate) enum TlsListenerConfig {
1416
Unconfigured,
15-
Acceptor(TlsAcceptor),
17+
Acceptor(Arc<dyn CustomTlsAcceptor>),
1618
ServerConfig(ServerConfig),
1719
Paths { cert: PathBuf, key: PathBuf },
1820
}

0 commit comments

Comments
 (0)