Skip to content

Commit f72899c

Browse files
DVNghiemgi0baro
andauthored
Add UDP support (#46)
* Implement UDP socket support and related transport management in the event loop * Refactor UDP socket connection and streamline datagram endpoint tests * Refactor UDP impl * Code lint * Add UDP write buffering --------- Co-authored-by: Giovanni Barillari <giovanni.barillari@sentry.io>
1 parent faf5cc6 commit f72899c

File tree

7 files changed

+727
-55
lines changed

7 files changed

+727
-55
lines changed

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ lint-rust:
2727
-D warnings \
2828
-W clippy::pedantic \
2929
-W clippy::dbg_macro \
30+
-A clippy::blocks_in_conditions \
3031
-A clippy::cast-possible-truncation \
3132
-A clippy::cast-sign-loss \
3233
-A clippy::declare-interior-mutable-const \

rloop/loop.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,12 +655,79 @@ async def create_datagram_endpoint(
655655
family=0,
656656
proto=0,
657657
flags=0,
658-
reuse_address=None,
658+
#: not in stdlib
659+
# reuse_address=None,
659660
reuse_port=None,
660661
allow_broadcast=None,
661662
sock=None,
662663
):
663-
raise NotImplementedError
664+
if sock is not None:
665+
if getattr(sock, 'type', None) != socket.SOCK_DGRAM:
666+
raise ValueError(f'A datagram socket was expected, got {sock!r}')
667+
if any((local_addr, remote_addr, family, proto, flags, reuse_port, allow_broadcast)):
668+
raise ValueError('socket modifier keyword arguments can not be used when sock is specified.')
669+
sock.setblocking(False)
670+
r_addr = None
671+
else:
672+
if not (local_addr or remote_addr):
673+
if family == 0:
674+
raise ValueError('unexpected address family')
675+
addr_info = (family, proto, None, None)
676+
elif hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
677+
for addr in (local_addr, remote_addr):
678+
if addr is not None and not isinstance(addr, str):
679+
raise TypeError('string is expected')
680+
addr_info = (family, proto, local_addr, remote_addr)
681+
else:
682+
addr_info, infos = None, None
683+
for addr in (local_addr, remote_addr):
684+
if addr is None:
685+
continue
686+
if not (isinstance(addr, tuple) and len(addr) == 2):
687+
raise TypeError('2-tuple is expected')
688+
infos = await self._ensure_resolved(
689+
addr, family=family, type=socket.SOCK_DGRAM, proto=proto, flags=flags
690+
)
691+
break
692+
693+
if not infos:
694+
raise OSError('getaddrinfo() returned empty list')
695+
if local_addr is not None:
696+
addr_info = (infos[0][0], infos[0][2], infos[0][4], None)
697+
if remote_addr is not None:
698+
addr_info = (infos[0][0], infos[0][2], None, infos[0][4])
699+
if not addr_info:
700+
raise ValueError('can not get address information')
701+
702+
sock = None
703+
r_addr = None
704+
sfam, spro, sladdr, sraddr = addr_info
705+
try:
706+
sock = socket.socket(family=sfam, type=socket.SOCK_DGRAM, proto=spro)
707+
#: not in stdlib
708+
# if reuse_address:
709+
# sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
710+
if reuse_port:
711+
_set_reuseport(sock)
712+
if allow_broadcast:
713+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
714+
sock.setblocking(False)
715+
if sladdr:
716+
sock.bind(sladdr)
717+
if sraddr:
718+
if not allow_broadcast:
719+
await self.sock_connect(sock, sraddr)
720+
r_addr = sraddr
721+
except OSError:
722+
if sock is not None:
723+
sock.close()
724+
raise
725+
726+
# Create the transport
727+
transport, protocol = self._udp_conn((sock.fileno(), sock.family), protocol_factory, r_addr)
728+
# sock is now owned by the transport, prevent close
729+
sock.detach()
730+
return transport, protocol
664731

665732
#: pipes and subprocesses methods
666733
async def connect_read_pipe(self, protocol_factory, pipe):

src/event_loop.rs

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ use crate::{
1919
server::Server,
2020
tcp::{TCPReadHandle, TCPServer, TCPServerRef, TCPTransport, TCPWriteHandle},
2121
time::Timer,
22+
udp::{UDPReadHandle, UDPTransport, UDPWriteHandle},
2223
};
2324

2425
enum IOHandle {
2526
Py(PyHandleData),
2627
Signals,
2728
TCPListener(TCPListenerHandleData),
2829
TCPStream(Interest),
30+
UDPSocket(Interest),
2931
}
3032

3133
struct PyHandleData {
@@ -73,6 +75,7 @@ pub struct EventLoop {
7375
task_factory: RwLock<PyObject>,
7476
tcp_lstreams: papaya::HashMap<usize, papaya::HashSet<usize>>,
7577
tcp_transports: papaya::HashMap<usize, Py<TCPTransport>>,
78+
udp_transports: papaya::HashMap<usize, Py<UDPTransport>>,
7679
thread_id: atomic::AtomicI64,
7780
watcher_child: RwLock<PyObject>,
7881
#[pyo3(get)]
@@ -149,6 +152,7 @@ impl EventLoop {
149152
IOHandle::Py(handle) => self.handle_io_py(py, event, handle, &mut cb_handles),
150153
IOHandle::TCPListener(handle) => self.handle_io_tcpl(py, handle, &io_handles, &mut cb_handles),
151154
IOHandle::TCPStream(_) => self.handle_io_tcps(event, &mut cb_handles),
155+
IOHandle::UDPSocket(_) => self.handle_io_udp(event, &mut cb_handles),
152156
IOHandle::Signals => self.handle_io_signals(py, &mut state.buf, &mut cb_handles),
153157
}
154158
}
@@ -231,7 +235,7 @@ impl EventLoop {
231235
let fd = stream.as_raw_fd() as usize;
232236
let token = Token(fd);
233237
#[allow(clippy::cast_possible_wrap)]
234-
let mut source = Source::TCPStream(fd as i32);
238+
let mut source = Source::FD(fd as i32);
235239
let (pytransport, stream_handle) = handle.server.new_stream(py, stream);
236240
transports.insert(fd, pytransport);
237241
lstreams.insert(fd);
@@ -254,6 +258,16 @@ impl EventLoop {
254258
}
255259
}
256260

261+
#[inline]
262+
fn handle_io_udp(&self, event: &event::Event, handles_ready: &mut VecDeque<BoxedHandle>) {
263+
let fd = event.token().0;
264+
if event.is_readable() {
265+
handles_ready.push_back(Box::new(UDPReadHandle { fd }));
266+
} else if event.is_writable() {
267+
handles_ready.push_back(Box::new(UDPWriteHandle { fd }));
268+
}
269+
}
270+
257271
#[inline]
258272
fn handle_io_signals(&self, py: Python, buf: &mut [u8], handles_ready: &mut VecDeque<BoxedHandle>) {
259273
let mut sock_guard = self.ssock.write().unwrap();
@@ -338,7 +352,7 @@ impl EventLoop {
338352
},
339353
|| {
340354
#[allow(clippy::cast_possible_wrap)]
341-
let mut source = Source::TCPStream(fd as i32);
355+
let mut source = Source::FD(fd as i32);
342356
{
343357
let guard_poll = self.io.lock().unwrap();
344358
_ = guard_poll.registry().register(&mut source, token, interest);
@@ -403,6 +417,83 @@ impl EventLoop {
403417
}
404418
}
405419

420+
#[inline]
421+
pub(crate) fn udp_socket_add(&self, fd: usize, interest: Interest) {
422+
let token = Token(fd);
423+
self.handles_io.pin().update_or_insert_with(
424+
token,
425+
|io_handle| {
426+
if let IOHandle::UDPSocket(interest_prev) = io_handle {
427+
if *interest_prev == interest {
428+
return IOHandle::UDPSocket(interest);
429+
}
430+
431+
let interests = *interest_prev | interest;
432+
{
433+
#[allow(clippy::cast_possible_wrap)]
434+
let mut source = Source::FD(fd as i32);
435+
let guard_poll = self.io.lock().unwrap();
436+
_ = guard_poll.registry().reregister(&mut source, token, interests);
437+
}
438+
return IOHandle::UDPSocket(interests);
439+
}
440+
unreachable!()
441+
},
442+
|| {
443+
#[allow(clippy::cast_possible_wrap)]
444+
let mut source = Source::FD(fd as i32);
445+
{
446+
let guard_poll = self.io.lock().unwrap();
447+
_ = guard_poll.registry().register(&mut source, token, interest);
448+
}
449+
IOHandle::UDPSocket(interest)
450+
},
451+
);
452+
}
453+
454+
#[inline]
455+
pub(crate) fn udp_socket_rem(&self, fd: usize, interest: Interest) {
456+
let token = Token(fd);
457+
458+
match self.handles_io.pin().remove_if(&token, |_, io_handle| {
459+
if let IOHandle::UDPSocket(interest_ex) = io_handle {
460+
return *interest_ex == interest;
461+
}
462+
false
463+
}) {
464+
Ok(None) => {}
465+
Ok(_) => {
466+
#[allow(clippy::cast_possible_wrap)]
467+
let mut source = Source::FD(fd as i32);
468+
let guard_poll = self.io.lock().unwrap();
469+
_ = guard_poll.registry().deregister(&mut source);
470+
}
471+
_ => {
472+
self.handles_io.pin().update(token, |io_handle| {
473+
if let IOHandle::UDPSocket(interest_ex) = io_handle {
474+
let interest_new = interest_ex.remove(interest).unwrap();
475+
#[allow(clippy::cast_possible_wrap)]
476+
let mut source = Source::FD(fd as i32);
477+
let guard_poll = self.io.lock().unwrap();
478+
_ = guard_poll.registry().reregister(&mut source, token, interest_new);
479+
return IOHandle::UDPSocket(interest_new);
480+
}
481+
unreachable!()
482+
});
483+
}
484+
}
485+
}
486+
487+
#[inline]
488+
pub(crate) fn udp_socket_close(&self, fd: usize) {
489+
self.udp_transports.pin().remove(&fd);
490+
}
491+
492+
#[inline(always)]
493+
pub(crate) fn get_udp_transport(&self, fd: usize, py: Python) -> Py<UDPTransport> {
494+
self.udp_transports.pin().get(&fd).unwrap().clone_ref(py)
495+
}
496+
406497
pub(crate) fn log_exception(&self, py: Python, ctx: LogExc) -> PyResult<PyObject> {
407498
let handler = self.exc_handler.read().unwrap();
408499
handler.call1(
@@ -631,6 +722,7 @@ impl EventLoop {
631722
task_factory: RwLock::new(py.None()),
632723
tcp_lstreams: papaya::HashMap::with_capacity(32),
633724
tcp_transports: papaya::HashMap::with_capacity(1024),
725+
udp_transports: papaya::HashMap::with_capacity(1024),
634726
thread_id: atomic::AtomicI64::new(0),
635727
watcher_child: RwLock::new(py.None()),
636728
_asyncgens: weakset(py)?.unbind(),
@@ -1110,6 +1202,23 @@ impl EventLoop {
11101202
self.tcp_transports.pin().contains_key(&fd)
11111203
}
11121204

1205+
fn _udp_conn(
1206+
pyself: Py<Self>,
1207+
py: Python,
1208+
sock: (i32, i32),
1209+
protocol_factory: PyObject,
1210+
remote_addr: Option<(String, u16)>,
1211+
) -> PyResult<(Py<UDPTransport>, PyObject)> {
1212+
let rself = pyself.get();
1213+
let transport = UDPTransport::from_py(py, &pyself, sock, protocol_factory, remote_addr);
1214+
let fd = transport.fd;
1215+
let pytransport = Py::new(py, transport)?;
1216+
let proto = UDPTransport::attach(&pytransport, py)?;
1217+
rself.udp_transports.pin().insert(fd, pytransport.clone_ref(py));
1218+
rself.udp_socket_add(fd, Interest::READABLE);
1219+
Ok((pytransport, proto))
1220+
}
1221+
11131222
fn _sig_add(&self, py: Python, sig: u8, callback: PyObject, args: PyObject, context: PyObject) {
11141223
let handle = Py::new(py, CBHandle::new(callback, args, context)).unwrap();
11151224
self.sig_handlers.pin().insert(sig, handle);

src/io.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@ use std::os::windows::io::RawSocket;
88
use mio::{Interest, Registry, Token, event::Source as MioSource, net::TcpListener};
99

1010
pub(crate) enum Source {
11-
TCPListener(TcpListener),
12-
#[cfg(unix)]
13-
TCPStream(RawFd),
14-
#[cfg(windows)]
15-
TCPStream(RawSocket),
1611
#[cfg(unix)]
1712
FD(RawFd),
1813
#[cfg(windows)]
1914
FD(RawSocket),
15+
TCPListener(TcpListener),
16+
// #[cfg(unix)]
17+
// TCPStream(RawFd),
18+
// #[cfg(windows)]
19+
// TCPStream(RawSocket),
20+
// #[cfg(unix)]
21+
// UDPSocket(RawFd),
22+
// #[cfg(windows)]
23+
// UDPSocket(RawSocket),
2024
}
2125

2226
#[cfg(windows)]
@@ -43,36 +47,33 @@ impl MioSource for Source {
4347
#[inline]
4448
fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> {
4549
match self {
46-
Self::TCPListener(inner) => inner.register(registry, token, interests),
47-
Self::TCPStream(inner) => SourceFd(inner).register(registry, token, interests),
4850
#[cfg(unix)]
4951
Self::FD(inner) => SourceFd(inner).register(registry, token, interests),
5052
#[cfg(windows)]
5153
Self::FD(inner) => SourceRawSocket(inner).register(registry, token, interests),
54+
Self::TCPListener(inner) => inner.register(registry, token, interests),
5255
}
5356
}
5457

5558
#[inline]
5659
fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> {
5760
match self {
58-
Self::TCPListener(inner) => inner.reregister(registry, token, interests),
59-
Self::TCPStream(inner) => SourceFd(inner).reregister(registry, token, interests),
6061
#[cfg(unix)]
6162
Self::FD(inner) => SourceFd(inner).reregister(registry, token, interests),
6263
#[cfg(windows)]
6364
Self::FD(inner) => SourceRawSocket(inner).register(registry, token, interests),
65+
Self::TCPListener(inner) => inner.reregister(registry, token, interests),
6466
}
6567
}
6668

6769
#[inline]
6870
fn deregister(&mut self, registry: &Registry) -> std::io::Result<()> {
6971
match self {
70-
Self::TCPListener(inner) => inner.deregister(registry),
71-
Self::TCPStream(inner) => SourceFd(inner).deregister(registry),
7272
#[cfg(unix)]
7373
Self::FD(inner) => SourceFd(inner).deregister(registry),
7474
#[cfg(windows)]
7575
Self::FD(inner) => SourceRawSocket(inner).register(registry, token, interests),
76+
Self::TCPListener(inner) => inner.deregister(registry),
7677
}
7778
}
7879
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ mod server;
1010
mod sock;
1111
mod tcp;
1212
mod time;
13+
mod udp;
1314
mod utils;
1415

1516
pub(crate) fn get_lib_version() -> &'static str {

0 commit comments

Comments
 (0)