@@ -3,7 +3,7 @@ use std::{future::Future, io, path::Path};
33use compio_buf:: { BufResult , IoBuf , IoBufMut , IoVectoredBuf , IoVectoredBufMut } ;
44use compio_io:: { AsyncRead , AsyncWrite } ;
55use compio_runtime:: { impl_attachable, impl_try_as_raw_fd} ;
6- use socket2:: { Domain , SockAddr , Type } ;
6+ use socket2:: { SockAddr , Type } ;
77
88use crate :: { OwnedReadHalf , OwnedWriteHalf , ReadHalf , Socket , WriteHalf } ;
99
@@ -25,8 +25,8 @@ use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf};
2525/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
2626/// let listener = UnixListener::bind(&sock_file).unwrap();
2727///
28- /// let mut tx = UnixStream::connect(&sock_file).unwrap();
29- /// let (mut rx, _) = listener.accept().await .unwrap();
28+ /// let ( mut tx, (mut rx, _)) =
29+ /// futures_util::try_join!(UnixStream::connect(&sock_file), listener.accept()) .unwrap();
3030///
3131/// tx.write_all("test").await.0.unwrap();
3232///
@@ -52,6 +52,13 @@ impl UnixListener {
5252 /// the specified file path. The file path cannot yet exist, and will be
5353 /// cleaned up upon dropping [`UnixListener`]
5454 pub fn bind_addr ( addr : & SockAddr ) -> io:: Result < Self > {
55+ if !addr. is_unix ( ) {
56+ return Err ( io:: Error :: new (
57+ io:: ErrorKind :: InvalidInput ,
58+ "addr is not unix socket address" ,
59+ ) ) ;
60+ }
61+
5562 let socket = Socket :: bind ( addr, Type :: STREAM , None ) ?;
5663 socket. listen ( 1024 ) ?;
5764 Ok ( UnixListener { inner : socket } )
@@ -106,7 +113,7 @@ impl_attachable!(UnixListener, inner);
106113///
107114/// # compio_runtime::Runtime::new().unwrap().block_on(async {
108115/// // Connect to a peer
109- /// let mut stream = UnixStream::connect("unix-server.sock").unwrap();
116+ /// let mut stream = UnixStream::connect("unix-server.sock").await. unwrap();
110117///
111118/// // Write some data.
112119/// stream.write("hello world!").await.unwrap();
@@ -121,16 +128,32 @@ impl UnixStream {
121128 /// Opens a Unix connection to the specified file path. There must be a
122129 /// [`UnixListener`] or equivalent listening on the corresponding Unix
123130 /// domain socket to successfully connect and return a `UnixStream`.
124- pub fn connect ( path : impl AsRef < Path > ) -> io:: Result < Self > {
125- Self :: connect_addr ( & SockAddr :: unix ( path) ?)
131+ pub async fn connect ( path : impl AsRef < Path > ) -> io:: Result < Self > {
132+ Self :: connect_addr ( & SockAddr :: unix ( path) ?) . await
126133 }
127134
128135 /// Opens a Unix connection to the specified address. There must be a
129136 /// [`UnixListener`] or equivalent listening on the corresponding Unix
130137 /// domain socket to successfully connect and return a `UnixStream`.
131- pub fn connect_addr ( addr : & SockAddr ) -> io:: Result < Self > {
132- let socket = Socket :: new ( Domain :: UNIX , Type :: STREAM , None ) ?;
133- socket. connect ( addr) ?;
138+ pub async fn connect_addr ( addr : & SockAddr ) -> io:: Result < Self > {
139+ if !addr. is_unix ( ) {
140+ return Err ( io:: Error :: new (
141+ io:: ErrorKind :: InvalidInput ,
142+ "addr is not unix socket address" ,
143+ ) ) ;
144+ }
145+
146+ #[ cfg( windows) ]
147+ let socket = {
148+ let new_addr = empty_unix_socket ( ) ;
149+ Socket :: bind ( & new_addr, Type :: STREAM , None ) ?
150+ } ;
151+ #[ cfg( unix) ]
152+ let socket = {
153+ use socket2:: Domain ;
154+ Socket :: new ( Domain :: UNIX , Type :: STREAM , None ) ?
155+ } ;
156+ socket. connect_async ( addr) . await ?;
134157 let unix_stream = UnixStream { inner : socket } ;
135158 Ok ( unix_stream)
136159 }
@@ -152,7 +175,13 @@ impl UnixStream {
152175
153176 /// Returns the socket path of the remote peer of this connection.
154177 pub fn peer_addr ( & self ) -> io:: Result < SockAddr > {
155- self . inner . peer_addr ( )
178+ #[ allow( unused_mut) ]
179+ let mut addr = self . inner . peer_addr ( ) ?;
180+ #[ cfg( windows) ]
181+ {
182+ fix_unix_socket_length ( & mut addr) ;
183+ }
184+ Ok ( addr)
156185 }
157186
158187 /// Returns the socket path of the local half of this connection.
@@ -251,3 +280,47 @@ impl AsyncWrite for &UnixStream {
251280impl_try_as_raw_fd ! ( UnixStream , inner) ;
252281
253282impl_attachable ! ( UnixStream , inner) ;
283+
284+ #[ cfg( windows) ]
285+ #[ inline]
286+ fn empty_unix_socket ( ) -> SockAddr {
287+ use windows_sys:: Win32 :: Networking :: WinSock :: { AF_UNIX , SOCKADDR_UN } ;
288+
289+ // SAFETY: the length is correct
290+ unsafe {
291+ SockAddr :: try_init ( |addr, len| {
292+ let addr: * mut SOCKADDR_UN = addr. cast ( ) ;
293+ std:: ptr:: write (
294+ addr,
295+ SOCKADDR_UN {
296+ sun_family : AF_UNIX ,
297+ sun_path : [ 0 ; 108 ] ,
298+ } ,
299+ ) ;
300+ std:: ptr:: write ( len, 3 ) ;
301+ Ok ( ( ) )
302+ } )
303+ }
304+ // it is always Ok
305+ . unwrap ( )
306+ . 1
307+ }
308+
309+ // The peer addr returned after ConnectEx is buggy. It contains bytes that
310+ // should not belong to the address. Luckily a unix path should not contain `\0`
311+ // until the end. We can determine the path ending by that.
312+ #[ cfg( windows) ]
313+ #[ inline]
314+ fn fix_unix_socket_length ( addr : & mut SockAddr ) {
315+ use windows_sys:: Win32 :: Networking :: WinSock :: SOCKADDR_UN ;
316+
317+ // SAFETY: cannot construct non-unix socket address in safe way.
318+ let unix_addr: & SOCKADDR_UN = unsafe { & * addr. as_ptr ( ) . cast ( ) } ;
319+ let addr_len = match std:: ffi:: CStr :: from_bytes_until_nul ( & unix_addr. sun_path ) {
320+ Ok ( str) => str. to_bytes_with_nul ( ) . len ( ) + 2 ,
321+ Err ( _) => std:: mem:: size_of :: < SOCKADDR_UN > ( ) ,
322+ } ;
323+ unsafe {
324+ addr. set_length ( addr_len as _ ) ;
325+ }
326+ }
0 commit comments