diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index f69c6a64c39ae6..475adde088c032 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -1,6 +1,8 @@ """Selector event loop for Unix with signal handling.""" import errno + +import functools import io import itertools import os @@ -67,6 +69,12 @@ def __init__(self, selector=None): else: self._watcher = _ThreadedChildWatcher() + if not hasattr(socket.socket, 'sendmsg'): + delattr(self, 'sock_sendmsg') + + if not hasattr(socket.socket, 'recvmsg'): + delattr(self, 'sock_recvmsg') + def close(self): super().close() if not sys.is_finalizing(): @@ -481,6 +489,104 @@ def _stop_serving(self, sock): logger.error('Unable to clean up listening UNIX socket ' '%r: %r', path, err) + async def sock_sendmsg(self, sock, data, ancdata=[], flags=0, address=None): + """Send datagram (data) and ancillary data to the socket (sock). + + The provided ancillary data is a list of zero or more tuples (data, ancdata, + msg_flags, address). flags represent various conditions and have the same + meaning as for send(). If address is supplied and not None, it sets a destination + address for the message it is the address of the sending socket. + """ + base_events._check_ssl_socket(sock) + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + try: + n = sock.sendmsg([data], ancdata, flags, address) + except (BlockingIOError, InterruptedError): + n = 0 + + if n == len(data): + # all data sent + return + + fut = self.create_future() + fd = sock.fileno() + self._ensure_fd_no_transport(fd) + # use a trick with a list in closure to store a mutable state + handle = self._add_writer(fd, self._sock_sendmsg, fut, sock, + memoryview(data), [n], ancdata, flags, address) + fut.add_done_callback( + functools.partial(self._sock_write_done, fd, handle=handle)) + return await fut + + def _sock_sendmsg(self, fut, sock, view, pos, ancdata, flags, address): + if fut.done(): + # Future cancellation can be scheduled on previous loop iteration + return + start = pos[0] + try: + n = sock.sendmsg([view[start:]], ancdata, flags, address) + except (BlockingIOError, InterruptedError): + return + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + fut.set_exception(exc) + return + + start += n + + if start == len(view): + fut.set_result(None) + else: + pos[0] = start + + async def sock_recvmsg(self, sock, bufsize, ancbufsize=0, flags=0): + """Receive normal data (up to bufsize bytes) and ancillary data from + the socket (sock). The socket must be non-blocking. + + The return value is a tuple of (data, ancdata, msg_flags, address). + data represents the datagram received. ancdata are the ancillary data + (control messages) as a list of tuples (cmsg_level, cmsg_type, cmsg_data), + where cmsg_level and cmsg_type are integers specifying the protocol level + and protocol-specific type respectively, and cmsg_data is a bytes object + holding the associated ancillary data. flags represent various conditions + (bitwise OR) on the received data. + The address is only specified if the receiving socket is unconnected. + Then it is the address of the sending socket. + """ + base_events._check_ssl_socket(sock) + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + try: + return sock.recvmsg(bufsize) + except (BlockingIOError, InterruptedError): + pass + fut = self.create_future() + fd = sock.fileno() + self._ensure_fd_no_transport(fd) + handle = self._add_reader(fd, self._sock_recvmsg, fut, sock, bufsize, ancbufsize, flags) + fut.add_done_callback( + functools.partial(self._sock_read_done, fd, handle=handle)) + return await fut + + def _sock_recvmsg(self, fut, sock, bufsize, ancbufsize, flags): + # _sock_recvmsg() can add itself as an I/O callback if the operation + # can't be done immediately. Don't use it directly, call + # sock_recvmsg(). + if fut.done(): + return + try: + result = sock.recvmsg(bufsize, ancbufsize, flags) + except (BlockingIOError, InterruptedError): + return # try again next time + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + fut.set_exception(exc) + else: + fut.set_result(result) + class _UnixReadPipeTransport(transports.ReadTransport): diff --git a/Lib/test/test_asyncio/test_sock_lowlevel.py b/Lib/test/test_asyncio/test_sock_lowlevel.py index 4f7b9a1dda6b78..3a40d378e4cb2f 100644 --- a/Lib/test/test_asyncio/test_sock_lowlevel.py +++ b/Lib/test/test_asyncio/test_sock_lowlevel.py @@ -1,10 +1,16 @@ +import ctypes + import socket import asyncio import sys +import struct import unittest from asyncio import proactor_events from itertools import cycle, islice + +from ipaddress import IPv4Address +from test.test_socket import requireAttrs from unittest.mock import Mock from test.test_asyncio import utils as test_utils from test import support @@ -427,6 +433,51 @@ def test_recvfrom_into(self): self.loop.run_until_complete( self._basetest_datagram_recvfrom_into(server_address)) + async def _basetest_datagram_sendmsg_recvmsg(self, server_address): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.setblocking(False) + sock.setsockopt(socket.IPPROTO_IP, socket.IP_PKTINFO, 1) + sock.setsockopt(socket.IPPROTO_IP, socket.IP_RECVTOS, 1) + sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, 1) + + data = b'\x01' * 4096 + ancsize = 10240 + + ancillary_data = [(socket.IPPROTO_IP, socket.IP_TOS, b"\x08")] + await self.loop.sock_sendmsg(sock, data, ancillary_data, address=server_address) + rec_data, ancdata_rcv, msg_flags, address = await self.loop.sock_recvmsg( + sock, len(data), ancsize) + + # Sent data is echoed back + self.assertEqual(data, rec_data) + + # ancillary data + self.assertEqual(2, len(ancdata_rcv)) + self.assertTrue(all(a[0] == socket.IPPROTO_IP for a in ancdata_rcv)) + # PKTINFO + ancdata_rcv_pktinfo = [d for d in ancdata_rcv if d[1] == socket.IP_PKTINFO] + self.assertEqual(1, len(ancdata_rcv_pktinfo)) + ancdata_rcv_pktinfo = ancdata_rcv_pktinfo[0] + # Not decoding the data. Just assert length as sanity check. + self.assertEqual(12, len(ancdata_rcv_pktinfo[2])) + # IP_RECVTOS + ancdata_rcv_rectos = [d for d in ancdata_rcv if d[1] == socket.IP_TOS] + self.assertEqual(1, len(ancdata_rcv_rectos)) + ancdata_rcv_rectos = ancdata_rcv_rectos[0] + tos = int.from_bytes(struct.unpack("c", ancdata_rcv_rectos[2])[0], "big") + # the testing server is sending an empty TOS + self.assertEqual(tos, 0) + + self.assertEqual(msg_flags, 0) + self.assertEqual(address[0], '127.0.0.1') + + @requireAttrs(socket.socket, 'recvmsg', 'sendmsg') + @unittest.skipUnless(sys.platform == 'linux', "Using ancillary data that are only available on Linux") + def test_sendmsg_recvmsg(self): + with test_utils.run_udp_echo_server() as server_address: + self.loop.run_until_complete( + self._basetest_datagram_sendmsg_recvmsg(server_address)) + async def _basetest_datagram_sendto_blocking(self, server_address): # Sad path, sock.sendto() raises BlockingIOError # This involves patching sock.sendto() to raise BlockingIOError but