Skip to content

Commit fe3be3c

Browse files
committed
Add AnyStream/AnyListener types to combine TCP and TLS streams/listeners.
Add `Accept` trait to abstract over different listener types.
1 parent cbd464e commit fe3be3c

File tree

10 files changed

+390
-40
lines changed

10 files changed

+390
-40
lines changed

src/net/common.rs

Lines changed: 223 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
//! Common types and traits shared between client and server TLS implementations.
22
3+
use std::any::TypeId;
34
use std::borrow::Cow;
5+
use std::pin::Pin;
6+
use std::task::{Context, Poll};
47
use std::{fmt, io};
58

6-
use mlua::{IntoLua, Lua, Result, Value};
9+
use mlua::{AnyUserData, Error, FromLua, IntoLua, Lua, MaybeSend, Result, Value};
10+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11+
12+
use super::tcp::{TcpListener, TcpStream};
13+
#[cfg(feature = "tls")]
14+
use super::tls::{TlsListener, TlsStream};
15+
#[cfg(unix)]
16+
use super::unix::{UnixListener, UnixStream};
717

818
/// Socket address that can be either TCP or Unix domain socket.
919
pub enum AnySocketAddr {
@@ -60,3 +70,215 @@ impl AddressProvider for tokio::net::UnixStream {
6070
Ok(AnySocketAddr::Unix(self.peer_addr()?))
6171
}
6272
}
73+
74+
/// A stream that can be either TCP or Unix domain socket, possibly wrapped in TLS.
75+
pub enum AnyStream {
76+
Tcp(TcpStream),
77+
#[cfg(unix)]
78+
Unix(UnixStream),
79+
#[cfg(feature = "tls")]
80+
TcpTls(TlsStream<TcpStream>),
81+
#[cfg(all(unix, feature = "tls"))]
82+
UnixTls(TlsStream<UnixStream>),
83+
}
84+
85+
impl AsyncRead for AnyStream {
86+
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
87+
match self.get_mut() {
88+
AnyStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
89+
#[cfg(unix)]
90+
AnyStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
91+
#[cfg(feature = "tls")]
92+
AnyStream::TcpTls(s) => Pin::new(s).poll_read(cx, buf),
93+
#[cfg(all(unix, feature = "tls"))]
94+
AnyStream::UnixTls(s) => Pin::new(s).poll_read(cx, buf),
95+
}
96+
}
97+
}
98+
99+
impl AsyncWrite for AnyStream {
100+
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
101+
match self.get_mut() {
102+
AnyStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
103+
#[cfg(unix)]
104+
AnyStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
105+
#[cfg(feature = "tls")]
106+
AnyStream::TcpTls(s) => Pin::new(s).poll_write(cx, buf),
107+
#[cfg(all(unix, feature = "tls"))]
108+
AnyStream::UnixTls(s) => Pin::new(s).poll_write(cx, buf),
109+
}
110+
}
111+
112+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
113+
match self.get_mut() {
114+
AnyStream::Tcp(s) => Pin::new(s).poll_flush(cx),
115+
#[cfg(unix)]
116+
AnyStream::Unix(s) => Pin::new(s).poll_flush(cx),
117+
#[cfg(feature = "tls")]
118+
AnyStream::TcpTls(s) => Pin::new(s).poll_flush(cx),
119+
#[cfg(all(unix, feature = "tls"))]
120+
AnyStream::UnixTls(s) => Pin::new(s).poll_flush(cx),
121+
}
122+
}
123+
124+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
125+
match self.get_mut() {
126+
AnyStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
127+
#[cfg(unix)]
128+
AnyStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
129+
#[cfg(feature = "tls")]
130+
AnyStream::TcpTls(s) => Pin::new(s).poll_shutdown(cx),
131+
#[cfg(all(unix, feature = "tls"))]
132+
AnyStream::UnixTls(s) => Pin::new(s).poll_shutdown(cx),
133+
}
134+
}
135+
}
136+
137+
impl IntoLua for AnyStream {
138+
fn into_lua(self, lua: &Lua) -> Result<Value> {
139+
match self {
140+
AnyStream::Tcp(s) => lua.create_userdata(s).map(Value::UserData),
141+
#[cfg(unix)]
142+
AnyStream::Unix(s) => lua.create_userdata(s).map(Value::UserData),
143+
#[cfg(feature = "tls")]
144+
AnyStream::TcpTls(s) => lua.create_userdata(s).map(Value::UserData),
145+
#[cfg(all(unix, feature = "tls"))]
146+
AnyStream::UnixTls(s) => lua.create_userdata(s).map(Value::UserData),
147+
}
148+
}
149+
}
150+
151+
impl FromLua for AnyStream {
152+
fn from_lua(value: Value, lua: &Lua) -> Result<Self> {
153+
let value = lua.unpack::<AnyUserData>(value)?;
154+
match value.type_id() {
155+
Some(id) if id == TypeId::of::<TcpStream>() => {
156+
let stream = value.take::<TcpStream>()?;
157+
Ok(AnyStream::Tcp(stream))
158+
}
159+
#[cfg(unix)]
160+
Some(id) if id == TypeId::of::<UnixStream>() => {
161+
let stream = value.take::<UnixStream>()?;
162+
Ok(AnyStream::Unix(stream))
163+
}
164+
#[cfg(feature = "tls")]
165+
Some(id) if id == TypeId::of::<TlsStream<TcpStream>>() => {
166+
let stream = value.take::<TlsStream<TcpStream>>()?;
167+
Ok(AnyStream::TcpTls(stream))
168+
}
169+
#[cfg(all(unix, feature = "tls"))]
170+
Some(id) if id == TypeId::of::<TlsStream<UnixStream>>() => {
171+
let stream = value.take::<TlsStream<UnixStream>>()?;
172+
Ok(AnyStream::UnixTls(stream))
173+
}
174+
_ => {
175+
let type_name = value.type_name().ok().flatten();
176+
let type_name = type_name.as_deref().unwrap_or("unknown");
177+
Err(Error::FromLuaConversionError {
178+
from: "UserData",
179+
to: "AnyStream".to_string(),
180+
message: Some(format!("expected TcpStream or UnixStream, got {type_name}",)),
181+
})
182+
}
183+
}
184+
}
185+
}
186+
187+
/// A listener that can be either TCP or Unix domain socket, possibly wrapped in TLS.
188+
pub enum AnyListener {
189+
Tcp(TcpListener),
190+
#[cfg(unix)]
191+
Unix(UnixListener),
192+
#[cfg(feature = "tls")]
193+
TcpTls(TlsListener<TcpListener>),
194+
#[cfg(all(unix, feature = "tls"))]
195+
UnixTls(TlsListener<UnixListener>),
196+
}
197+
198+
/// Trait for accepting incoming connections from various listener types.
199+
pub trait Accept {
200+
type Stream: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static;
201+
202+
/// Get the local address that the listener is bound to.
203+
fn local_addr(&self) -> io::Result<AnySocketAddr>;
204+
205+
/// Accept an incoming connection.
206+
#[allow(async_fn_in_trait)]
207+
async fn accept(&self) -> io::Result<(Self::Stream, AnySocketAddr)>;
208+
}
209+
210+
impl Accept for AnyListener {
211+
type Stream = AnyStream;
212+
213+
fn local_addr(&self) -> io::Result<AnySocketAddr> {
214+
match self {
215+
AnyListener::Tcp(l) => l.local_addr(),
216+
#[cfg(unix)]
217+
AnyListener::Unix(l) => l.local_addr(),
218+
#[cfg(feature = "tls")]
219+
AnyListener::TcpTls(l) => l.local_addr(),
220+
#[cfg(all(unix, feature = "tls"))]
221+
AnyListener::UnixTls(l) => l.local_addr(),
222+
}
223+
}
224+
225+
async fn accept(&self) -> io::Result<(AnyStream, AnySocketAddr)> {
226+
match self {
227+
AnyListener::Tcp(listener) => {
228+
let (stream, addr) = listener.accept().await?;
229+
Ok((AnyStream::Tcp(stream), addr))
230+
}
231+
#[cfg(unix)]
232+
AnyListener::Unix(listener) => {
233+
let (stream, addr) = listener.accept().await?;
234+
Ok((AnyStream::Unix(stream), addr))
235+
}
236+
#[cfg(feature = "tls")]
237+
AnyListener::TcpTls(listener) => {
238+
let (stream, addr) = listener.accept().await?;
239+
Ok((AnyStream::TcpTls(stream), addr))
240+
}
241+
#[cfg(all(unix, feature = "tls"))]
242+
AnyListener::UnixTls(listener) => {
243+
let (stream, addr) = listener.accept().await?;
244+
Ok((AnyStream::UnixTls(stream), addr))
245+
}
246+
}
247+
}
248+
}
249+
250+
impl FromLua for AnyListener {
251+
fn from_lua(value: Value, lua: &Lua) -> Result<Self> {
252+
let value = lua.unpack::<AnyUserData>(value)?;
253+
match value.type_id() {
254+
Some(id) if id == TypeId::of::<TcpListener>() => {
255+
let listener = value.take::<TcpListener>()?;
256+
Ok(AnyListener::Tcp(listener))
257+
}
258+
#[cfg(unix)]
259+
Some(id) if id == TypeId::of::<UnixListener>() => {
260+
let listener = value.take::<UnixListener>()?;
261+
Ok(AnyListener::Unix(listener))
262+
}
263+
#[cfg(feature = "tls")]
264+
Some(id) if id == TypeId::of::<TlsListener<TcpListener>>() => {
265+
let listener = value.take::<TlsListener<TcpListener>>()?;
266+
Ok(AnyListener::TcpTls(listener))
267+
}
268+
#[cfg(all(unix, feature = "tls"))]
269+
Some(id) if id == TypeId::of::<TlsListener<UnixListener>>() => {
270+
let listener = value.take::<TlsListener<UnixListener>>()?;
271+
Ok(AnyListener::UnixTls(listener))
272+
}
273+
_ => {
274+
let type_name = value.type_name().ok().flatten();
275+
let type_name = type_name.as_deref().unwrap_or("unknown");
276+
Err(Error::FromLuaConversionError {
277+
from: "UserData",
278+
to: "AnyListener".to_string(),
279+
message: Some(format!("expected TcpListener or UnixListener, got {type_name}")),
280+
})
281+
}
282+
}
283+
}
284+
}

src/net/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use mlua::{Lua, Result, Table};
22

3-
pub use common::{AddressProvider, AnySocketAddr};
3+
pub use common::{AddressProvider, AnyListener, AnySocketAddr, AnyStream};
4+
pub use tcp::{TcpListener, TcpStream};
45

56
/// A loader for the `net` module.
67
fn loader(lua: &Lua) -> Result<Table> {
@@ -27,7 +28,7 @@ macro_rules! with_io_timeout {
2728
};
2829
}
2930

30-
mod common;
31+
pub(crate) mod common;
3132

3233
pub mod tcp;
3334
#[cfg(feature = "tls")]

src/net/tcp/listener.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,23 @@ use mlua::{Lua, Result, Table, UserData, UserDataMethods, UserDataRegistry};
66
use tokio::net::lookup_host;
77

88
use super::{SocketOptions, TcpSocket, TcpStream};
9-
use crate::net::common::AnySocketAddr;
9+
use crate::net::common::{Accept, AnySocketAddr};
1010

1111
pub struct TcpListener(pub(crate) tokio::net::TcpListener);
1212

13-
impl TcpListener {
14-
pub(crate) fn local_addr(&self) -> io::Result<AnySocketAddr> {
13+
impl Accept for TcpListener {
14+
type Stream = TcpStream;
15+
16+
fn local_addr(&self) -> io::Result<AnySocketAddr> {
1517
self.0.local_addr().map(AnySocketAddr::IP)
1618
}
19+
20+
async fn accept(&self) -> io::Result<(Self::Stream, AnySocketAddr)> {
21+
let (stream, addr) = self.0.accept().await?;
22+
let io = TcpStream::from(stream);
23+
let addr = AnySocketAddr::IP(addr);
24+
Ok((io, addr))
25+
}
1726
}
1827

1928
impl UserData for TcpListener {

src/net/tcp/stream.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use std::io;
22
use std::net::SocketAddr;
33
use std::ops::{Deref, DerefMut};
4+
use std::pin::Pin;
45
use std::result::Result as StdResult;
6+
use std::task::{Context, Poll};
57

68
use mlua::{Lua, Result, String as LuaString, Table, UserData, UserDataMethods, UserDataRegistry};
7-
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
9+
use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, ReadBuf};
810
use tokio::net::lookup_host;
911

1012
use super::{SocketOptions, TcpSocket};
@@ -55,6 +57,30 @@ impl AddressProvider for TcpStream {
5557
}
5658
}
5759

60+
impl AsyncRead for TcpStream {
61+
#[inline]
62+
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
63+
Pin::new(&mut self.get_mut().stream).poll_read(cx, buf)
64+
}
65+
}
66+
67+
impl AsyncWrite for TcpStream {
68+
#[inline]
69+
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
70+
Pin::new(&mut self.get_mut().stream).poll_write(cx, buf)
71+
}
72+
73+
#[inline]
74+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
75+
Pin::new(&mut self.get_mut().stream).poll_flush(cx)
76+
}
77+
78+
#[inline]
79+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
80+
Pin::new(&mut self.get_mut().stream).poll_shutdown(cx)
81+
}
82+
}
83+
5884
impl UserData for TcpStream {
5985
fn register(registry: &mut UserDataRegistry<Self>) {
6086
registry.add_async_function("connect", connect);

src/net/tls/client.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use rustls::pki_types::{CertificateDer, DnsName, ServerName, UnixTime};
1212
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, SignatureScheme};
1313

1414
use super::stream::TlsStream;
15-
use crate::net::tcp::TcpStream;
1615
#[cfg(unix)]
1716
use crate::net::unix::UnixStream;
17+
use crate::net::{AnyStream, TcpStream};
1818

1919
/// TLS configuration options for client connections
2020
#[derive(Debug, Clone)]
@@ -222,9 +222,9 @@ static DEFAULT_TLS_CLIENT_CONFIG: LazyLock<TlsClientConfig> =
222222

223223
/// Wrap a stream with TLS (client-side).
224224
pub async fn wrap_stream(
225-
lua: Lua,
225+
_lua: Lua,
226226
(stream, server_name, config): (AnyUserData, Option<String>, Option<TlsClientConfig>),
227-
) -> LuaResult<StdResult<AnyUserData, String>> {
227+
) -> LuaResult<StdResult<AnyStream, String>> {
228228
let server_name = server_name.and_then(|name| ServerName::try_from(name).ok());
229229
let config = config.unwrap_or_else(|| DEFAULT_TLS_CLIENT_CONFIG.clone());
230230

@@ -236,11 +236,11 @@ pub async fn wrap_stream(
236236
.or_else(|| host.and_then(|host| ServerName::try_from(host).ok()))
237237
.or_else(|| stream.peer_addr().map(|addr| ServerName::from(addr.ip())).ok())
238238
.unwrap_or_else(default_server_name);
239-
match TlsStream::new_client(stream, server_name, config).await {
239+
match TlsStream::new_client(stream.into(), server_name, config).await {
240240
Ok(mut tls_stream) => {
241241
tls_stream.set_read_timeout(read_timeout);
242242
tls_stream.set_write_timeout(write_timeout);
243-
Ok(Ok(lua.create_userdata(tls_stream)?))
243+
Ok(Ok(AnyStream::TcpTls(tls_stream)))
244244
}
245245
Err(e) => Ok(Err(e.to_string())),
246246
}
@@ -250,11 +250,11 @@ pub async fn wrap_stream(
250250
#[rustfmt::skip]
251251
let UnixStream { stream, read_timeout, write_timeout } = stream.take::<UnixStream>()?;
252252
let server_name = server_name.unwrap_or_else(default_server_name);
253-
match TlsStream::new_client(stream, server_name, config).await {
253+
match TlsStream::new_client(stream.into(), server_name, config).await {
254254
Ok(mut tls_stream) => {
255255
tls_stream.set_read_timeout(read_timeout);
256256
tls_stream.set_write_timeout(write_timeout);
257-
Ok(Ok(lua.create_userdata(tls_stream)?))
257+
Ok(Ok(AnyStream::UnixTls(tls_stream)))
258258
}
259259
Err(e) => Ok(Err(e.to_string())),
260260
}

src/net/tls/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
33
use mlua::{Result, Table};
44

5+
pub use server::TlsListener;
6+
pub use stream::TlsStream;
7+
58
/// Registers the `tls` module in the given Lua state.
69
pub fn register(lua: &mlua::Lua, name: Option<&str>) -> Result<Table> {
710
let name = name.unwrap_or("@tls");

0 commit comments

Comments
 (0)