Skip to content

Commit b0fc2af

Browse files
committed
Add Lua prefix to net types
1 parent 17c5cd3 commit b0fc2af

File tree

15 files changed

+261
-171
lines changed

15 files changed

+261
-171
lines changed

src/net/common.rs

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@ use std::{fmt, io};
88

99
use mlua::{AnyUserData, Error, FromLua, IntoLua, Lua, MaybeSend, Result, Value};
1010
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11+
use tokio::net::TcpStream;
12+
#[cfg(unix)]
13+
use tokio::net::UnixStream;
1114

12-
use super::tcp::{TcpListener, TcpStream};
15+
use super::tcp::{LuaTcpListener, LuaTcpStream};
1316
#[cfg(feature = "tls")]
14-
use super::tls::{TlsListener, TlsStream};
17+
use super::tls::{LuaTlsListener, LuaTlsStream};
1518
#[cfg(unix)]
16-
use super::unix::{UnixListener, UnixStream};
19+
use super::unix::{LuaUnixListener, LuaUnixStream};
1720

1821
/// Socket address that can be either TCP or Unix domain socket.
1922
pub enum AnySocketAddr {
@@ -50,7 +53,7 @@ pub trait AddressProvider {
5053
fn peer_addr(&self) -> io::Result<AnySocketAddr>;
5154
}
5255

53-
impl AddressProvider for tokio::net::TcpStream {
56+
impl AddressProvider for TcpStream {
5457
fn local_addr(&self) -> io::Result<AnySocketAddr> {
5558
Ok(AnySocketAddr::IP(self.local_addr()?))
5659
}
@@ -61,7 +64,7 @@ impl AddressProvider for tokio::net::TcpStream {
6164
}
6265

6366
#[cfg(unix)]
64-
impl AddressProvider for tokio::net::UnixStream {
67+
impl AddressProvider for UnixStream {
6568
fn local_addr(&self) -> io::Result<AnySocketAddr> {
6669
Ok(AnySocketAddr::Unix(self.local_addr()?))
6770
}
@@ -73,13 +76,13 @@ impl AddressProvider for tokio::net::UnixStream {
7376

7477
/// A stream that can be either TCP or Unix domain socket, possibly wrapped in TLS.
7578
pub enum AnyStream {
76-
Tcp(TcpStream),
79+
Tcp(LuaTcpStream),
7780
#[cfg(unix)]
78-
Unix(UnixStream),
81+
Unix(LuaUnixStream),
7982
#[cfg(feature = "tls")]
80-
TcpTls(TlsStream<TcpStream>),
83+
TcpTls(LuaTlsStream<LuaTcpStream>),
8184
#[cfg(all(unix, feature = "tls"))]
82-
UnixTls(TlsStream<UnixStream>),
85+
UnixTls(LuaTlsStream<LuaUnixStream>),
8386
}
8487

8588
impl AsyncRead for AnyStream {
@@ -152,23 +155,23 @@ impl FromLua for AnyStream {
152155
fn from_lua(value: Value, lua: &Lua) -> Result<Self> {
153156
let value = lua.unpack::<AnyUserData>(value)?;
154157
match value.type_id() {
155-
Some(id) if id == TypeId::of::<TcpStream>() => {
156-
let stream = value.take::<TcpStream>()?;
158+
Some(id) if id == TypeId::of::<LuaTcpStream>() => {
159+
let stream = value.take::<LuaTcpStream>()?;
157160
Ok(AnyStream::Tcp(stream))
158161
}
159162
#[cfg(unix)]
160-
Some(id) if id == TypeId::of::<UnixStream>() => {
161-
let stream = value.take::<UnixStream>()?;
163+
Some(id) if id == TypeId::of::<LuaUnixStream>() => {
164+
let stream = value.take::<LuaUnixStream>()?;
162165
Ok(AnyStream::Unix(stream))
163166
}
164167
#[cfg(feature = "tls")]
165-
Some(id) if id == TypeId::of::<TlsStream<TcpStream>>() => {
166-
let stream = value.take::<TlsStream<TcpStream>>()?;
168+
Some(id) if id == TypeId::of::<LuaTlsStream<LuaTcpStream>>() => {
169+
let stream = value.take::<LuaTlsStream<LuaTcpStream>>()?;
167170
Ok(AnyStream::TcpTls(stream))
168171
}
169172
#[cfg(all(unix, feature = "tls"))]
170-
Some(id) if id == TypeId::of::<TlsStream<UnixStream>>() => {
171-
let stream = value.take::<TlsStream<UnixStream>>()?;
173+
Some(id) if id == TypeId::of::<LuaTlsStream<LuaUnixStream>>() => {
174+
let stream = value.take::<LuaTlsStream<LuaUnixStream>>()?;
172175
Ok(AnyStream::UnixTls(stream))
173176
}
174177
_ => {
@@ -186,13 +189,13 @@ impl FromLua for AnyStream {
186189

187190
/// A listener that can be either TCP or Unix domain socket, possibly wrapped in TLS.
188191
pub enum AnyListener {
189-
Tcp(TcpListener),
192+
Tcp(LuaTcpListener),
190193
#[cfg(unix)]
191-
Unix(UnixListener),
194+
Unix(LuaUnixListener),
192195
#[cfg(feature = "tls")]
193-
TcpTls(TlsListener<TcpListener>),
196+
TcpTls(LuaTlsListener<LuaTcpListener>),
194197
#[cfg(all(unix, feature = "tls"))]
195-
UnixTls(TlsListener<UnixListener>),
198+
UnixTls(LuaTlsListener<LuaUnixListener>),
196199
}
197200

198201
/// Trait for accepting incoming connections from various listener types.
@@ -251,23 +254,23 @@ impl FromLua for AnyListener {
251254
fn from_lua(value: Value, lua: &Lua) -> Result<Self> {
252255
let value = lua.unpack::<AnyUserData>(value)?;
253256
match value.type_id() {
254-
Some(id) if id == TypeId::of::<TcpListener>() => {
255-
let listener = value.take::<TcpListener>()?;
257+
Some(id) if id == TypeId::of::<LuaTcpListener>() => {
258+
let listener = value.take::<LuaTcpListener>()?;
256259
Ok(AnyListener::Tcp(listener))
257260
}
258261
#[cfg(unix)]
259-
Some(id) if id == TypeId::of::<UnixListener>() => {
260-
let listener = value.take::<UnixListener>()?;
262+
Some(id) if id == TypeId::of::<LuaUnixListener>() => {
263+
let listener = value.take::<LuaUnixListener>()?;
261264
Ok(AnyListener::Unix(listener))
262265
}
263266
#[cfg(feature = "tls")]
264-
Some(id) if id == TypeId::of::<TlsListener<TcpListener>>() => {
265-
let listener = value.take::<TlsListener<TcpListener>>()?;
267+
Some(id) if id == TypeId::of::<LuaTlsListener<LuaTcpListener>>() => {
268+
let listener = value.take::<LuaTlsListener<LuaTcpListener>>()?;
266269
Ok(AnyListener::TcpTls(listener))
267270
}
268271
#[cfg(all(unix, feature = "tls"))]
269-
Some(id) if id == TypeId::of::<TlsListener<UnixListener>>() => {
270-
let listener = value.take::<TlsListener<UnixListener>>()?;
272+
Some(id) if id == TypeId::of::<LuaTlsListener<LuaUnixListener>>() => {
273+
let listener = value.take::<LuaTlsListener<LuaUnixListener>>()?;
271274
Ok(AnyListener::UnixTls(listener))
272275
}
273276
_ => {

src/net/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use mlua::{Lua, Result, Table};
22

33
pub use common::{AddressProvider, AnyListener, AnySocketAddr, AnyStream};
4-
pub use tcp::{TcpListener, TcpStream};
4+
pub use tcp::{LuaTcpListener, LuaTcpStream};
55

66
/// A loader for the `net` module.
77
fn loader(lua: &Lua) -> Result<Table> {

src/net/tcp/listener.rs

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,94 @@
11
use std::io;
22
use std::net::SocketAddr;
3+
use std::ops::{Deref, DerefMut};
34
use std::result::Result as StdResult;
45

56
use mlua::{Lua, Result, Table, UserData, UserDataMethods, UserDataRegistry};
6-
use tokio::net::lookup_host;
7+
use tokio::net::{TcpListener, lookup_host};
78

8-
use super::{SocketOptions, TcpSocket, TcpStream};
9+
use super::{LuaTcpSocket, LuaTcpStream, SocketOptions};
910
use crate::net::common::{Accept, AnySocketAddr};
1011

11-
pub struct TcpListener(pub(crate) tokio::net::TcpListener);
12+
/// Lua wrapper around tokio [`TcpListener`].
13+
#[derive(Debug)]
14+
pub struct LuaTcpListener(pub(crate) TcpListener);
1215

13-
impl Accept for TcpListener {
14-
type Stream = TcpStream;
16+
impl Deref for LuaTcpListener {
17+
type Target = TcpListener;
18+
19+
#[inline]
20+
fn deref(&self) -> &Self::Target {
21+
&self.0
22+
}
23+
}
24+
25+
impl DerefMut for LuaTcpListener {
26+
#[inline]
27+
fn deref_mut(&mut self) -> &mut Self::Target {
28+
&mut self.0
29+
}
30+
}
31+
32+
impl From<TcpListener> for LuaTcpListener {
33+
#[inline]
34+
fn from(listener: TcpListener) -> Self {
35+
LuaTcpListener(listener)
36+
}
37+
}
38+
39+
impl Accept for LuaTcpListener {
40+
type Stream = LuaTcpStream;
1541

1642
fn local_addr(&self) -> io::Result<AnySocketAddr> {
1743
self.0.local_addr().map(AnySocketAddr::IP)
1844
}
1945

2046
async fn accept(&self) -> io::Result<(Self::Stream, AnySocketAddr)> {
2147
let (stream, addr) = self.0.accept().await?;
22-
let io = TcpStream::from(stream);
23-
let addr = AnySocketAddr::IP(addr);
24-
Ok((io, addr))
48+
Ok((stream.into(), AnySocketAddr::IP(addr)))
2549
}
2650
}
2751

28-
impl UserData for TcpListener {
52+
impl UserData for LuaTcpListener {
2953
fn register(registry: &mut UserDataRegistry<Self>) {
3054
registry.add_method("local_addr", |_, this, ()| Ok(this.local_addr()?));
3155

3256
registry.add_async_function("listen", listen);
3357

3458
registry.add_async_method("accept", |_, this, ()| async move {
3559
let (stream, _) = lua_try!(this.0.accept().await);
36-
Ok(Ok(TcpStream::from(stream)))
60+
Ok(Ok(LuaTcpStream::from(stream)))
3761
});
3862
}
3963
}
4064

65+
/// Creates a TCP listener bound to the specified address.
66+
///
67+
/// # Arguments
68+
/// * `addr`: The address to bind to.
69+
/// * `port`: The port to bind to. If `None`, a random available port will be used.
70+
/// * `params` (optional): A table of socket options.
71+
///
72+
/// The following options can be specified:
73+
/// * `backlog`: The maximum number of pending connections. Default is 1024.
4174
pub async fn listen(
4275
_: Lua,
4376
(addr, port, params): (String, Option<u16>, Option<Table>),
44-
) -> Result<StdResult<TcpListener, String>> {
77+
) -> Result<StdResult<LuaTcpListener, String>> {
4578
let port = port.unwrap_or(0);
4679
let addrs = lua_try!(lookup_host((addr, port)).await);
4780

4881
let sock_options = SocketOptions::from_table(&params)?;
4982
let backlog = opt_param!(params, "backlog")?;
5083

5184
let try_listen = |addr: SocketAddr| {
52-
let sock = TcpSocket::new_for_addr(addr)?;
85+
let sock = LuaTcpSocket::new_for_addr(addr)?;
5386
sock.set_options(sock_options)?;
5487
sock.0.set_reuseaddr(true)?;
5588
sock.0.bind(addr)?;
56-
let listener = TcpListener(sock.0.listen(backlog.unwrap_or(1024))?);
57-
io::Result::Ok(listener)
89+
let backlog = backlog.unwrap_or(1024);
90+
let listener = sock.0.listen(backlog)?;
91+
io::Result::Ok(listener.into())
5892
};
5993

6094
let mut last_err = None;

src/net/tcp/mod.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
use mlua::{Lua, Result, Table};
22

3-
pub use listener::{TcpListener, listen};
4-
pub use stream::{TcpStream, connect};
3+
pub use listener::{LuaTcpListener, listen};
4+
pub use stream::{LuaTcpStream, connect};
55

6-
use socket::{SocketOptions, TcpSocket};
6+
use socket::{SocketOptions, LuaTcpSocket};
77

88
/// A loader for the `net/tcp` module.
99
fn loader(lua: &Lua) -> Result<Table> {
1010
let t = lua.create_table()?;
11-
t.set("TcpListener", lua.create_proxy::<TcpListener>()?)?;
12-
t.set("TcpStream", lua.create_proxy::<TcpStream>()?)?;
11+
t.set("TcpListener", lua.create_proxy::<LuaTcpListener>()?)?;
12+
t.set("TcpStream", lua.create_proxy::<LuaTcpStream>()?)?;
1313
t.set("listen", lua.create_async_function(listen)?)?;
1414
t.set("connect", lua.create_async_function(connect)?)?;
1515
Ok(t)

src/net/tcp/socket.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ use std::net::SocketAddr;
33
use std::ops::Deref;
44

55
use mlua::{Result, Table};
6+
use tokio::net::TcpSocket;
67

7-
pub(crate) struct TcpSocket(pub(crate) tokio::net::TcpSocket);
8+
pub(crate) struct LuaTcpSocket(pub(crate) TcpSocket);
89

9-
impl Deref for TcpSocket {
10-
type Target = tokio::net::TcpSocket;
10+
impl Deref for LuaTcpSocket {
11+
type Target = TcpSocket;
1112

1213
#[inline]
1314
fn deref(&self) -> &Self::Target {
@@ -25,13 +26,13 @@ pub(super) struct SocketOptions {
2526
reuseport: Option<bool>,
2627
}
2728

28-
impl TcpSocket {
29+
impl LuaTcpSocket {
2930
pub(crate) fn new_for_addr(addr: SocketAddr) -> io::Result<Self> {
3031
let sock = match addr {
31-
SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
32-
SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
32+
SocketAddr::V4(_) => TcpSocket::new_v4()?,
33+
SocketAddr::V6(_) => TcpSocket::new_v6()?,
3334
};
34-
Ok(TcpSocket(sock))
35+
Ok(LuaTcpSocket(sock))
3536
}
3637

3738
pub(crate) fn set_options(&self, options: SocketOptions) -> io::Result<()> {

0 commit comments

Comments
 (0)