Skip to content

Commit 47150fa

Browse files
authored
feat(tls): add LazyConfigAcceptor for rustls (#686)
1 parent 74070d2 commit 47150fa

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

compio-tls/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,8 @@ mod stream;
2323
pub use adapter::*;
2424
pub use maybe::*;
2525
pub use stream::*;
26+
27+
#[cfg(feature = "rustls")]
28+
mod rtls;
29+
#[cfg(feature = "rustls")]
30+
pub use rtls::*;

compio-tls/src/rtls.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
use std::{
2+
io,
3+
pin::Pin,
4+
sync::Arc,
5+
task::{Context, Poll},
6+
};
7+
8+
use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream};
9+
use futures_util::FutureExt;
10+
use rustls::{
11+
ServerConfig, ServerConnection,
12+
server::{Acceptor, ClientHello},
13+
};
14+
15+
use crate::TlsStream;
16+
17+
/// A lazy TLS acceptor that performs the initial handshake and allows access to
18+
/// the [`ClientHello`] message before completing the handshake.
19+
pub struct LazyConfigAcceptor<S>(futures_rustls::LazyConfigAcceptor<AsyncStream<S>>);
20+
21+
impl<S: AsyncRead + AsyncWrite + 'static> LazyConfigAcceptor<S> {
22+
/// Create a new [`LazyConfigAcceptor`] with the given acceptor and stream.
23+
pub fn new(acceptor: Acceptor, s: S) -> Self {
24+
Self(futures_rustls::LazyConfigAcceptor::new(
25+
acceptor,
26+
AsyncStream::new(s),
27+
))
28+
}
29+
}
30+
31+
impl<S: AsyncRead + AsyncWrite + 'static> Future for LazyConfigAcceptor<S> {
32+
type Output = Result<StartHandshake<S>, io::Error>;
33+
34+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
35+
self.0.poll_unpin(cx).map_ok(StartHandshake)
36+
}
37+
}
38+
39+
/// A TLS acceptor that has completed the initial handshake and allows access to
40+
/// the [`ClientHello`] message.
41+
pub struct StartHandshake<S>(futures_rustls::StartHandshake<AsyncStream<S>>);
42+
43+
impl<S: AsyncRead + AsyncWrite + 'static> StartHandshake<S> {
44+
/// Get the [`ClientHello`] message from the initial handshake.
45+
pub fn client_hello(&self) -> ClientHello<'_> {
46+
self.0.client_hello()
47+
}
48+
49+
/// Complete the TLS handshake and return a [`TlsStream`] if successful.
50+
pub fn into_stream(
51+
self,
52+
config: Arc<ServerConfig>,
53+
) -> impl Future<Output = io::Result<TlsStream<S>>> {
54+
self.into_stream_with(config, |_| ())
55+
}
56+
57+
/// Complete the TLS handshake and return a [`TlsStream`] if successful.
58+
pub fn into_stream_with<F>(
59+
self,
60+
config: Arc<ServerConfig>,
61+
f: F,
62+
) -> impl Future<Output = io::Result<TlsStream<S>>>
63+
where
64+
F: FnOnce(&mut ServerConnection),
65+
{
66+
self.0
67+
.into_stream_with(config, f)
68+
.map(|res| res.map(TlsStream::from))
69+
}
70+
}

0 commit comments

Comments
 (0)