Skip to content

Commit 376000c

Browse files
authored
Merge pull request #212 from Berrysoft:fix/unix-connect
feat(net): async connect unix stream
2 parents c349605 + ca01e3f commit 376000c

File tree

4 files changed

+91
-18
lines changed

4 files changed

+91
-18
lines changed

compio-net/src/unix.rs

Lines changed: 83 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{future::Future, io, path::Path};
33
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
44
use compio_io::{AsyncRead, AsyncWrite};
55
use compio_runtime::{impl_attachable, impl_try_as_raw_fd};
6-
use socket2::{Domain, SockAddr, Type};
6+
use socket2::{SockAddr, Type};
77

88
use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf};
99

@@ -25,8 +25,8 @@ use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf};
2525
/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
2626
/// let listener = UnixListener::bind(&sock_file).unwrap();
2727
///
28-
/// let mut tx = UnixStream::connect(&sock_file).unwrap();
29-
/// let (mut rx, _) = listener.accept().await.unwrap();
28+
/// let (mut tx, (mut rx, _)) =
29+
/// futures_util::try_join!(UnixStream::connect(&sock_file), listener.accept()).unwrap();
3030
///
3131
/// tx.write_all("test").await.0.unwrap();
3232
///
@@ -52,6 +52,13 @@ impl UnixListener {
5252
/// the specified file path. The file path cannot yet exist, and will be
5353
/// cleaned up upon dropping [`UnixListener`]
5454
pub fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
55+
if !addr.is_unix() {
56+
return Err(io::Error::new(
57+
io::ErrorKind::InvalidInput,
58+
"addr is not unix socket address",
59+
));
60+
}
61+
5562
let socket = Socket::bind(addr, Type::STREAM, None)?;
5663
socket.listen(1024)?;
5764
Ok(UnixListener { inner: socket })
@@ -106,7 +113,7 @@ impl_attachable!(UnixListener, inner);
106113
///
107114
/// # compio_runtime::Runtime::new().unwrap().block_on(async {
108115
/// // Connect to a peer
109-
/// let mut stream = UnixStream::connect("unix-server.sock").unwrap();
116+
/// let mut stream = UnixStream::connect("unix-server.sock").await.unwrap();
110117
///
111118
/// // Write some data.
112119
/// stream.write("hello world!").await.unwrap();
@@ -121,16 +128,32 @@ impl UnixStream {
121128
/// Opens a Unix connection to the specified file path. There must be a
122129
/// [`UnixListener`] or equivalent listening on the corresponding Unix
123130
/// domain socket to successfully connect and return a `UnixStream`.
124-
pub fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
125-
Self::connect_addr(&SockAddr::unix(path)?)
131+
pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
132+
Self::connect_addr(&SockAddr::unix(path)?).await
126133
}
127134

128135
/// Opens a Unix connection to the specified address. There must be a
129136
/// [`UnixListener`] or equivalent listening on the corresponding Unix
130137
/// domain socket to successfully connect and return a `UnixStream`.
131-
pub fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
132-
let socket = Socket::new(Domain::UNIX, Type::STREAM, None)?;
133-
socket.connect(addr)?;
138+
pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
139+
if !addr.is_unix() {
140+
return Err(io::Error::new(
141+
io::ErrorKind::InvalidInput,
142+
"addr is not unix socket address",
143+
));
144+
}
145+
146+
#[cfg(windows)]
147+
let socket = {
148+
let new_addr = empty_unix_socket();
149+
Socket::bind(&new_addr, Type::STREAM, None)?
150+
};
151+
#[cfg(unix)]
152+
let socket = {
153+
use socket2::Domain;
154+
Socket::new(Domain::UNIX, Type::STREAM, None)?
155+
};
156+
socket.connect_async(addr).await?;
134157
let unix_stream = UnixStream { inner: socket };
135158
Ok(unix_stream)
136159
}
@@ -152,7 +175,13 @@ impl UnixStream {
152175

153176
/// Returns the socket path of the remote peer of this connection.
154177
pub fn peer_addr(&self) -> io::Result<SockAddr> {
155-
self.inner.peer_addr()
178+
#[allow(unused_mut)]
179+
let mut addr = self.inner.peer_addr()?;
180+
#[cfg(windows)]
181+
{
182+
fix_unix_socket_length(&mut addr);
183+
}
184+
Ok(addr)
156185
}
157186

158187
/// Returns the socket path of the local half of this connection.
@@ -251,3 +280,47 @@ impl AsyncWrite for &UnixStream {
251280
impl_try_as_raw_fd!(UnixStream, inner);
252281

253282
impl_attachable!(UnixStream, inner);
283+
284+
#[cfg(windows)]
285+
#[inline]
286+
fn empty_unix_socket() -> SockAddr {
287+
use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};
288+
289+
// SAFETY: the length is correct
290+
unsafe {
291+
SockAddr::try_init(|addr, len| {
292+
let addr: *mut SOCKADDR_UN = addr.cast();
293+
std::ptr::write(
294+
addr,
295+
SOCKADDR_UN {
296+
sun_family: AF_UNIX,
297+
sun_path: [0; 108],
298+
},
299+
);
300+
std::ptr::write(len, 3);
301+
Ok(())
302+
})
303+
}
304+
// it is always Ok
305+
.unwrap()
306+
.1
307+
}
308+
309+
// The peer addr returned after ConnectEx is buggy. It contains bytes that
310+
// should not belong to the address. Luckily a unix path should not contain `\0`
311+
// until the end. We can determine the path ending by that.
312+
#[cfg(windows)]
313+
#[inline]
314+
fn fix_unix_socket_length(addr: &mut SockAddr) {
315+
use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;
316+
317+
// SAFETY: cannot construct non-unix socket address in safe way.
318+
let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() };
319+
let addr_len = match std::ffi::CStr::from_bytes_until_nul(&unix_addr.sun_path) {
320+
Ok(str) => str.to_bytes_with_nul().len() + 2,
321+
Err(_) => std::mem::size_of::<SOCKADDR_UN>(),
322+
};
323+
unsafe {
324+
addr.set_length(addr_len as _);
325+
}
326+
}

compio-net/tests/split.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ async fn unix_split() {
6868

6969
let listener = UnixListener::bind(&sock_path).unwrap();
7070

71-
let client = UnixStream::connect(&sock_path).unwrap();
72-
let (server, _) = listener.accept().await.unwrap();
71+
let (client, (server, _)) =
72+
futures_util::try_join!(UnixStream::connect(&sock_path), listener.accept()).unwrap();
7373

7474
let (mut a_read, mut a_write) = server.into_split();
7575
let (mut b_read, mut b_write) = client.into_split();

compio-net/tests/unix_stream.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ async fn accept_read_write() -> std::io::Result<()> {
1111

1212
let listener = UnixListener::bind(&sock_path)?;
1313

14-
let mut client = UnixStream::connect(&sock_path)?;
15-
let (mut server, _) = listener.accept().await?;
14+
let (mut client, (mut server, _)) =
15+
futures_util::try_join!(UnixStream::connect(&sock_path), listener.accept()).unwrap();
1616

1717
client.write_all("hello").await.0?;
1818
drop(client);
@@ -35,8 +35,8 @@ async fn shutdown() -> std::io::Result<()> {
3535

3636
let listener = UnixListener::bind(&sock_path)?;
3737

38-
let mut client = UnixStream::connect(&sock_path)?;
39-
let (mut server, _) = listener.accept().await?;
38+
let (mut client, (mut server, _)) =
39+
futures_util::try_join!(UnixStream::connect(&sock_path), listener.accept()).unwrap();
4040

4141
// Shut down the client
4242
client.shutdown().await?;

compio/examples/unix.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ async fn main() {
1212

1313
let addr = listener.local_addr().unwrap();
1414

15-
let mut tx = UnixStream::connect_addr(&addr).unwrap();
16-
let (mut rx, _) = listener.accept().await.unwrap();
15+
let (mut tx, (mut rx, _)) =
16+
futures_util::try_join!(UnixStream::connect_addr(&addr), listener.accept()).unwrap();
1717

1818
assert_eq!(addr, tx.peer_addr().unwrap());
1919

0 commit comments

Comments
 (0)