Skip to content

Commit e42c6f1

Browse files
authored
feat(net): incoming stream (#759)
* feat(net): incoming * feat(net): tcp & unix incoming * test(net): incoming * feat(net): incoming on windows * feat(net): impl FusedStream for *Incoming * test(net): incoming uds * fix(net): deps * refactor(net): move incoming to a separate mod * fix(net): don't issue new AcceptMulti before next poll
1 parent 12324fe commit e42c6f1

File tree

9 files changed

+340
-5
lines changed

9 files changed

+340
-5
lines changed

compio-net/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ compio-runtime = { workspace = true }
2222

2323
cfg-if = { workspace = true }
2424
either = "1.9.0"
25+
futures-util = { workspace = true }
2526
once_cell = { workspace = true }
2627
socket2 = { workspace = true }
27-
futures-util = { workspace = true }
2828

2929
[target.'cfg(windows)'.dependencies]
3030
widestring = { workspace = true }

compio-net/src/incoming/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
cfg_if::cfg_if! {
2+
if #[cfg(windows)] {
3+
#[path = "windows.rs"]
4+
mod sys;
5+
} else if #[cfg(unix)] {
6+
#[path = "unix.rs"]
7+
mod sys;
8+
}
9+
}
10+
11+
pub use sys::*;

compio-net/src/incoming/unix.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
use std::{
2+
io,
3+
os::fd::FromRawFd,
4+
pin::Pin,
5+
task::{Context, Poll, ready},
6+
};
7+
8+
use compio_buf::{BufResult, IntoInner};
9+
use compio_driver::{SharedFd, ToSharedFd, op::AcceptMulti};
10+
use compio_runtime::SubmitMulti;
11+
use futures_util::{Stream, StreamExt, stream::FusedStream};
12+
use socket2::Socket as Socket2;
13+
14+
use crate::Socket;
15+
16+
pub struct Incoming<'a> {
17+
listener: &'a Socket,
18+
op: Option<SubmitMulti<AcceptMulti<SharedFd<Socket2>>>>,
19+
}
20+
21+
impl<'a> Incoming<'a> {
22+
pub fn new(listener: &'a Socket) -> Self {
23+
Self { listener, op: None }
24+
}
25+
}
26+
27+
impl Stream for Incoming<'_> {
28+
type Item = io::Result<Socket>;
29+
30+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
31+
let this = self.get_mut();
32+
loop {
33+
if let Some(op) = &mut this.op {
34+
let res = ready!(op.poll_next_unpin(cx));
35+
if let Some(BufResult(res, _)) = res {
36+
let socket = if op.is_terminated() && res.is_ok() {
37+
let Some(op) = this.op.take() else {
38+
// SAFETY: op is guaranteed to be Some at this point.
39+
unsafe { std::hint::unreachable_unchecked() }
40+
};
41+
op.try_take()
42+
.map_err(|_| ())
43+
.expect("AcceptMulti has not completed")
44+
.into_inner()
45+
} else {
46+
unsafe { Socket2::from_raw_fd(res? as _) }
47+
};
48+
return Poll::Ready(Some(Socket::from_socket2(socket)));
49+
} else {
50+
this.op = None;
51+
}
52+
} else {
53+
this.op = Some(compio_runtime::submit_multi(AcceptMulti::new(
54+
this.listener.to_shared_fd(),
55+
)));
56+
}
57+
}
58+
}
59+
}
60+
61+
impl FusedStream for Incoming<'_> {
62+
fn is_terminated(&self) -> bool {
63+
false
64+
}
65+
}

compio-net/src/incoming/windows.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
use std::{
2+
io,
3+
pin::Pin,
4+
task::{Context, Poll, ready},
5+
};
6+
7+
use compio_buf::BufResult;
8+
use compio_driver::{SharedFd, ToSharedFd, op::Accept};
9+
use compio_runtime::{JoinHandle, Submit};
10+
use futures_util::{FutureExt, Stream, stream::FusedStream};
11+
use socket2::Socket as Socket2;
12+
13+
use crate::Socket;
14+
15+
#[allow(clippy::large_enum_variant)]
16+
enum IncomingState {
17+
Idle,
18+
CreatingSocket(JoinHandle<io::Result<Socket2>>),
19+
Accepting(Submit<Accept<SharedFd<Socket2>>>),
20+
}
21+
22+
pub struct Incoming<'a> {
23+
listener: &'a Socket,
24+
state: IncomingState,
25+
}
26+
27+
impl<'a> Incoming<'a> {
28+
pub fn new(listener: &'a Socket) -> Self {
29+
Self {
30+
listener,
31+
state: IncomingState::Idle,
32+
}
33+
}
34+
}
35+
36+
impl Stream for Incoming<'_> {
37+
type Item = io::Result<Socket>;
38+
39+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
40+
let this = self.get_mut();
41+
loop {
42+
match &mut this.state {
43+
IncomingState::Idle => {
44+
let domain = this.listener.local_addr().map(|addr| addr.domain())?;
45+
let ty = this.listener.socket.r#type()?;
46+
let protocol = this.listener.socket.protocol()?;
47+
let handle =
48+
compio_runtime::spawn_blocking(move || Socket2::new(domain, ty, protocol));
49+
this.state = IncomingState::CreatingSocket(handle);
50+
}
51+
IncomingState::CreatingSocket(handle) => match ready!(handle.poll_unpin(cx)) {
52+
Ok(Ok(socket)) => {
53+
let op = compio_runtime::submit(Accept::new(
54+
this.listener.to_shared_fd(),
55+
socket,
56+
));
57+
this.state = IncomingState::Accepting(op);
58+
}
59+
Ok(Err(e)) => {
60+
this.state = IncomingState::Idle;
61+
return Poll::Ready(Some(Err(e)));
62+
}
63+
Err(e) => {
64+
this.state = IncomingState::Idle;
65+
std::panic::resume_unwind(e)
66+
}
67+
},
68+
IncomingState::Accepting(op) => {
69+
let BufResult(res, op) = ready!(op.poll_unpin(cx));
70+
match res {
71+
Ok(_) => {
72+
this.state = IncomingState::Idle;
73+
op.update_context()?;
74+
let (accept_sock, _) = op.into_addr()?;
75+
return Poll::Ready(Some(Ok(Socket::from_socket2(accept_sock)?)));
76+
}
77+
Err(e) => {
78+
this.state = IncomingState::Idle;
79+
return Poll::Ready(Some(Err(e)));
80+
}
81+
}
82+
}
83+
}
84+
}
85+
}
86+
}
87+
88+
impl FusedStream for Incoming<'_> {
89+
fn is_terminated(&self) -> bool {
90+
false
91+
}
92+
}

compio-net/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
1515
)]
1616

17+
mod incoming;
1718
mod opts;
1819
mod resolve;
1920
mod socket;
@@ -46,6 +47,7 @@ pub type CMsgBuilder<'a> = compio_io::ancillary::AncillaryBuilder<'a>;
4647
/// Providing functionalities to wait for readiness.
4748
#[deprecated(since = "0.12.0", note = "Use `compio::runtime::fd::PollFd` instead")]
4849
pub type PollFd<T> = compio_runtime::fd::PollFd<T>;
50+
pub(crate) use incoming::*;
4951
pub use opts::SocketOpts;
5052
pub use resolve::ToSocketAddrsAsync;
5153
pub(crate) use resolve::{each_addr, first_addr_buf, first_addr_buf_zerocopy};

compio-net/src/socket.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ use compio_runtime::{Attacher, BorrowedBuffer, BufferPool, fd::PollFd};
2121
use futures_util::StreamExt;
2222
use socket2::{Domain, Protocol, SockAddr, Socket as Socket2, Type};
2323

24+
use crate::Incoming;
25+
2426
#[derive(Debug, Clone)]
2527
pub struct Socket {
2628
pub(crate) socket: Attacher<Socket2>,
@@ -121,6 +123,10 @@ impl Socket {
121123
Ok((Self::from_socket2(accept_sock)?, addr))
122124
}
123125

126+
pub fn incoming(&self) -> Incoming<'_> {
127+
Incoming::new(self)
128+
}
129+
124130
pub fn close(self) -> impl Future<Output = io::Result<()>> {
125131
// Make sure that self won't be dropped after `close` called.
126132
// Users may call this method and drop the future immediately. In that way the

compio-net/src/tcp.rs

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
1-
use std::{future::Future, io, net::SocketAddr};
1+
use std::{
2+
future::Future,
3+
io,
4+
net::SocketAddr,
5+
pin::Pin,
6+
task::{Context, Poll},
7+
};
28

39
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
410
use compio_driver::impl_raw_fd;
511
use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
612
use compio_runtime::{BorrowedBuffer, BufferPool, fd::PollFd};
13+
use futures_util::{Stream, StreamExt, stream::FusedStream};
714
use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
815

916
use crate::{
10-
OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, ToSocketAddrsAsync, WriteHalf,
17+
Incoming, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, ToSocketAddrsAsync,
18+
WriteHalf,
1119
};
1220

1321
/// A TCP socket server, listening for connections.
@@ -120,6 +128,20 @@ impl TcpListener {
120128
Ok((stream, addr.as_socket().expect("should be SocketAddr")))
121129
}
122130

131+
/// Returns a stream of incoming connections to this listener.
132+
pub fn incoming(&self) -> TcpIncoming<'_> {
133+
self.incoming_with_options(&SocketOpts::default())
134+
}
135+
136+
/// Returns a stream of incoming connections to this listener, and sets
137+
/// options for each accepted connection.
138+
pub fn incoming_with_options<'a>(&'a self, options: &SocketOpts) -> TcpIncoming<'a> {
139+
TcpIncoming {
140+
inner: self.inner.incoming(),
141+
opts: *options,
142+
}
143+
}
144+
123145
/// Returns the local address that this listener is bound to.
124146
///
125147
/// This can be useful, for example, when binding to port 0 to
@@ -152,6 +174,33 @@ impl TcpListener {
152174

153175
impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
154176

177+
/// A stream of incoming TCP connections.
178+
pub struct TcpIncoming<'a> {
179+
inner: Incoming<'a>,
180+
opts: SocketOpts,
181+
}
182+
183+
impl Stream for TcpIncoming<'_> {
184+
type Item = io::Result<TcpStream>;
185+
186+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187+
let this = self.get_mut();
188+
this.inner.poll_next_unpin(cx).map(|res| {
189+
res.map(|res| {
190+
let socket = res?;
191+
this.opts.setup_socket(&socket)?;
192+
Ok(TcpStream { inner: socket })
193+
})
194+
})
195+
}
196+
}
197+
198+
impl FusedStream for TcpIncoming<'_> {
199+
fn is_terminated(&self) -> bool {
200+
self.inner.is_terminated()
201+
}
202+
}
203+
155204
/// A TCP stream between a local and a remote socket.
156205
///
157206
/// A TCP stream can either be created by connecting to an endpoint, via the

compio-net/src/unix.rs

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1-
use std::{future::Future, io, path::Path};
1+
use std::{
2+
future::Future,
3+
io,
4+
path::Path,
5+
pin::Pin,
6+
task::{Context, Poll},
7+
};
28

39
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
410
use compio_driver::impl_raw_fd;
511
use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
612
use compio_runtime::{BorrowedBuffer, BufferPool, fd::PollFd};
13+
use futures_util::{Stream, StreamExt, stream::FusedStream};
714
use socket2::{SockAddr, Socket as Socket2, Type};
815

9-
use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, WriteHalf};
16+
use crate::{Incoming, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, WriteHalf};
1017

1118
/// A Unix socket server, listening for connections.
1219
///
@@ -113,6 +120,20 @@ impl UnixListener {
113120
Ok((stream, addr))
114121
}
115122

123+
/// Returns a stream of incoming connections to this listener.
124+
pub fn incoming(&self) -> UnixIncoming<'_> {
125+
self.incoming_with_options(&SocketOpts::default())
126+
}
127+
128+
/// Returns a stream of incoming connections to this listener, and sets
129+
/// options for each accepted connection.
130+
pub fn incoming_with_options<'a>(&'a self, options: &SocketOpts) -> UnixIncoming<'a> {
131+
UnixIncoming {
132+
inner: self.inner.incoming(),
133+
opts: *options,
134+
}
135+
}
136+
116137
/// Returns the local address that this listener is bound to.
117138
pub fn local_addr(&self) -> io::Result<SockAddr> {
118139
self.inner.local_addr()
@@ -121,6 +142,33 @@ impl UnixListener {
121142

122143
impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
123144

145+
/// A stream of incoming Unix connections.
146+
pub struct UnixIncoming<'a> {
147+
inner: Incoming<'a>,
148+
opts: SocketOpts,
149+
}
150+
151+
impl Stream for UnixIncoming<'_> {
152+
type Item = io::Result<UnixStream>;
153+
154+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
155+
let this = self.get_mut();
156+
this.inner.poll_next_unpin(cx).map(|res| {
157+
res.map(|res| {
158+
let socket = res?;
159+
this.opts.setup_socket(&socket)?;
160+
Ok(UnixStream { inner: socket })
161+
})
162+
})
163+
}
164+
}
165+
166+
impl FusedStream for UnixIncoming<'_> {
167+
fn is_terminated(&self) -> bool {
168+
self.inner.is_terminated()
169+
}
170+
}
171+
124172
/// A Unix stream between two local sockets on Windows & WSL.
125173
///
126174
/// A Unix stream can either be created by connecting to an endpoint, via the

0 commit comments

Comments
 (0)