Skip to content

Commit bfb14fc

Browse files
committed
feat(net): async connect UnixStream
1 parent 902b711 commit bfb14fc

File tree

4 files changed

+63
-18
lines changed

4 files changed

+63
-18
lines changed

compio-net/src/unix.rs

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{future::Future, io, path::Path};
33
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
44
use compio_io::{AsyncRead, AsyncWrite};
55
use compio_runtime::{impl_attachable, impl_try_as_raw_fd};
6-
use socket2::{Domain, SockAddr, Type};
6+
use socket2::{SockAddr, Type};
77

88
use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf};
99

@@ -25,8 +25,8 @@ use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf};
2525
/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
2626
/// let listener = UnixListener::bind(&sock_file).unwrap();
2727
///
28-
/// let mut tx = UnixStream::connect(&sock_file).unwrap();
29-
/// let (mut rx, _) = listener.accept().await.unwrap();
28+
/// let (mut tx, (mut rx, _)) =
29+
/// futures_util::try_join!(UnixStream::connect(&sock_file), listener.accept()).unwrap();
3030
///
3131
/// tx.write_all("test").await.0.unwrap();
3232
///
@@ -106,7 +106,7 @@ impl_attachable!(UnixListener, inner);
106106
///
107107
/// # compio_runtime::Runtime::new().unwrap().block_on(async {
108108
/// // Connect to a peer
109-
/// let mut stream = UnixStream::connect("unix-server.sock").unwrap();
109+
/// let mut stream = UnixStream::connect("unix-server.sock").await.unwrap();
110110
///
111111
/// // Write some data.
112112
/// stream.write("hello world!").await.unwrap();
@@ -121,16 +121,43 @@ impl UnixStream {
121121
/// Opens a Unix connection to the specified file path. There must be a
122122
/// [`UnixListener`] or equivalent listening on the corresponding Unix
123123
/// domain socket to successfully connect and return a `UnixStream`.
124-
pub fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
125-
Self::connect_addr(&SockAddr::unix(path)?)
124+
pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
125+
Self::connect_addr(&SockAddr::unix(path)?).await
126126
}
127127

128128
/// Opens a Unix connection to the specified address. There must be a
129129
/// [`UnixListener`] or equivalent listening on the corresponding Unix
130130
/// domain socket to successfully connect and return a `UnixStream`.
131-
pub fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
132-
let socket = Socket::new(Domain::UNIX, Type::STREAM, None)?;
133-
socket.connect(addr)?;
131+
pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
132+
#[cfg(windows)]
133+
let socket = {
134+
use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};
135+
136+
let new_addr = unsafe {
137+
SockAddr::try_init(|addr, len| {
138+
let addr: *mut SOCKADDR_UN = addr.cast();
139+
std::ptr::write(
140+
addr,
141+
SOCKADDR_UN {
142+
sun_family: AF_UNIX,
143+
sun_path: [0; 108],
144+
},
145+
);
146+
std::ptr::write(len, 3);
147+
Ok(())
148+
})
149+
}
150+
// it is always Ok
151+
.unwrap()
152+
.1;
153+
Socket::bind(&new_addr, Type::STREAM, None)?
154+
};
155+
#[cfg(unix)]
156+
let socket = {
157+
use socket2::Domain;
158+
Socket::new(Domain::UNIX, Type::STREAM, None)?
159+
};
160+
socket.connect_async(addr).await?;
134161
let unix_stream = UnixStream { inner: socket };
135162
Ok(unix_stream)
136163
}
@@ -152,7 +179,25 @@ impl UnixStream {
152179

153180
/// Returns the socket path of the remote peer of this connection.
154181
pub fn peer_addr(&self) -> io::Result<SockAddr> {
155-
self.inner.peer_addr()
182+
#[allow(unused_mut)]
183+
let mut addr = self.inner.peer_addr()?;
184+
// The peer addr returned after ConnectEx is buggy. It contains bytes that
185+
// should not belong to the address. Luckily a unix path should not contain `\0`
186+
// until the end. We can determine the path ending by that.
187+
#[cfg(windows)]
188+
{
189+
use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;
190+
191+
let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() };
192+
let addr_len = match std::ffi::CStr::from_bytes_until_nul(&unix_addr.sun_path) {
193+
Ok(str) => str.to_bytes_with_nul().len() + 2,
194+
Err(_) => std::mem::size_of::<SOCKADDR_UN>(),
195+
};
196+
unsafe {
197+
addr.set_length(addr_len as _);
198+
}
199+
}
200+
Ok(addr)
156201
}
157202

158203
/// Returns the socket path of the local half of this connection.

compio-net/tests/split.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ async fn unix_split() {
6868

6969
let listener = UnixListener::bind(&sock_path).unwrap();
7070

71-
let client = UnixStream::connect(&sock_path).unwrap();
72-
let (server, _) = listener.accept().await.unwrap();
71+
let (client, (server, _)) =
72+
futures_util::try_join!(UnixStream::connect(&sock_path), listener.accept()).unwrap();
7373

7474
let (mut a_read, mut a_write) = server.into_split();
7575
let (mut b_read, mut b_write) = client.into_split();

compio-net/tests/unix_stream.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ async fn accept_read_write() -> std::io::Result<()> {
1111

1212
let listener = UnixListener::bind(&sock_path)?;
1313

14-
let mut client = UnixStream::connect(&sock_path)?;
15-
let (mut server, _) = listener.accept().await?;
14+
let (mut client, (mut server, _)) =
15+
futures_util::try_join!(UnixStream::connect(&sock_path), listener.accept()).unwrap();
1616

1717
client.write_all("hello").await.0?;
1818
drop(client);
@@ -35,8 +35,8 @@ async fn shutdown() -> std::io::Result<()> {
3535

3636
let listener = UnixListener::bind(&sock_path)?;
3737

38-
let mut client = UnixStream::connect(&sock_path)?;
39-
let (mut server, _) = listener.accept().await?;
38+
let (mut client, (mut server, _)) =
39+
futures_util::try_join!(UnixStream::connect(&sock_path), listener.accept()).unwrap();
4040

4141
// Shut down the client
4242
client.shutdown().await?;

compio/examples/unix.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ async fn main() {
1212

1313
let addr = listener.local_addr().unwrap();
1414

15-
let mut tx = UnixStream::connect_addr(&addr).unwrap();
16-
let (mut rx, _) = listener.accept().await.unwrap();
15+
let (mut tx, (mut rx, _)) =
16+
futures_util::try_join!(UnixStream::connect_addr(&addr), listener.accept()).unwrap();
1717

1818
assert_eq!(addr, tx.peer_addr().unwrap());
1919

0 commit comments

Comments
 (0)