diff --git a/Doc/library/socket.rst b/Doc/library/socket.rst index 935d4a85342876..decb9e8ccbe289 100644 --- a/Doc/library/socket.rst +++ b/Doc/library/socket.rst @@ -1373,10 +1373,6 @@ The :mod:`socket` module also offers various network-related services: .. versionadded:: 3.9 - .. note:: - - Any truncated integers at the end of the list of file descriptors. - .. _socket-objects: diff --git a/Lib/socket.py b/Lib/socket.py index be37c24d6174a2..2935c20bb0849f 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -563,7 +563,7 @@ def send_fds(sock, buffers, fds, flags=0, address=None): import array return sock.sendmsg(buffers, [(_socket.SOL_SOCKET, - _socket.SCM_RIGHTS, array.array("i", fds))]) + _socket.SCM_RIGHTS, array.array("i", fds))], flags, address) __all__.append("send_fds") if hasattr(_socket.socket, "recvmsg"): @@ -578,14 +578,15 @@ def recv_fds(sock, bufsize, maxfds, flags=0): # Array of ints fds = array.array("i") - msg, ancdata, flags, addr = sock.recvmsg(bufsize, - _socket.CMSG_LEN(maxfds * fds.itemsize)) + msg, ancdata, msg_flags, addr = sock.recvmsg(bufsize, + _socket.CMSG_LEN(maxfds * fds.itemsize), flags) for cmsg_level, cmsg_type, cmsg_data in ancdata: if (cmsg_level == _socket.SOL_SOCKET and cmsg_type == _socket.SCM_RIGHTS): + # Append data, ignoring any truncated integers at the end. fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) - return msg, list(fds), flags, addr + return msg, list(fds), msg_flags, addr __all__.append("recv_fds") if hasattr(_socket.socket, "share"): diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 663fa50c086c13..1e8501a0744f82 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -6994,42 +6994,131 @@ def test_dual_stack_client_v6(self): @requireAttrs(socket, "recv_fds") @requireAttrs(socket, "AF_UNIX") class SendRecvFdsTests(unittest.TestCase): - def testSendAndRecvFds(self): - def close_pipes(pipes): - for fd1, fd2 in pipes: - os.close(fd1) - os.close(fd2) - def close_fds(fds): - for fd in fds: - os.close(fd) + def setUp(self): + # make 13 pipes, reading side in rds, writing side in wds + rds, wds = zip(*(os.pipe() for _ in range(13))) + # checks to verify that test setup is OK + self.assertEqual(len(rds), len(wds)) + self.check_pipes_connected(wds, rds) + + self.rds, self.wds = rds, wds + + def tearDown(self): + self.close_fds(self.rds + self.wds) + + @staticmethod + def close_fds(fds): + for fd in fds: + os.close(fd) - # send 10 file descriptors - pipes = [os.pipe() for _ in range(10)] - self.addCleanup(close_pipes, pipes) - fds = [rfd for rfd, wfd in pipes] + def check_pipes_connected(self, wds, rds): + for index, fd in enumerate(wds): + os.write(fd, str(index).encode()) + + for index, fd in enumerate(rds): + os.set_blocking(fd, False) + data = os.read(fd, 1024) + self.assertEqual(data, str(index).encode()) + + def testSendAndRecvFds(self): + """basic test over a local socket pair""" # use a UNIX socket pair to exchange file descriptors locally sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) with sock1, sock2: - socket.send_fds(sock1, [MSG], fds) - # request more data and file descriptors than expected - msg, fds2, flags, addr = socket.recv_fds(sock2, len(MSG) * 2, len(fds) * 2) - self.addCleanup(close_fds, fds2) + # send the reading fds over sock1 + socket.send_fds(sock1, [MSG], self.rds) + # receive them on sock2 + # allocate space for more data and file descriptors than expected + msg, rds2, msg_flags, addr = socket.recv_fds( + sock2, len(MSG) * 2, len(self.rds) * 2 + ) + self.addCleanup(self.close_fds, rds2) self.assertEqual(msg, MSG) - self.assertEqual(len(fds2), len(fds)) - self.assertEqual(flags, 0) - # don't test addr + self.assertEqual(len(rds2), len(self.rds)) + self.assertEqual(msg_flags, 0) + # addr contains no useful info and is not checked + self.check_pipes_connected(self.wds, rds2) - # test that file descriptors are connected - for index, fds in enumerate(pipes): - rfd, wfd = fds - os.write(wfd, str(index).encode()) + def testRecvFlags(self): + """same as testSendAndRecvFds but set recv_fds flags to MSG_PEEK""" + + # use a UNIX socket pair to exchange file descriptors locally + # set to non blocking to avoid hangs on errors + with socket_setdefaulttimeout(0): + sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + with sock1, sock2: + # send the reading fds over sock1 + socket.send_fds(sock1, [MSG], self.rds) + # receive them on sock2, with socket.MSG_PEEK + # allocate space for more data and file descriptors than expected + msg, rds2, msg_flags, addr = socket.recv_fds( + sock2, len(MSG) * 2, len(self.rds) * 2, socket.MSG_PEEK + ) + + self.assertEqual(msg, MSG) + self.assertEqual(len(rds2), len(self.rds)) + # msg_flags can be 0 or socket.MSG_PEEK: + self.assertEqual(msg_flags & ~socket.MSG_PEEK, 0) + # addr contains no useful info and is not checked + if msg_flags == 0: + # rds2 are open and connected + self.addCleanup(self.close_fds, rds2) + self.check_pipes_connected(self.wds, rds2) + elif msg_flags == socket.MSG_PEEK: + # rds2 are not open fds and should be all 0s + self.assertEqual(rds2, [0]*len(rds2)) + else: + assert False, "message msg_flags has unexpected value" + + # receive again without socket.MSG_PEEK + msg, rds2, msg_flags, addr = socket.recv_fds( + sock2, len(MSG) * 2, len(self.rds) * 2 + ) + self.addCleanup(self.close_fds, rds2) + + self.assertEqual(msg, MSG) + self.assertEqual(len(rds2), len(self.rds)) + self.assertEqual(msg_flags, 0) + # addr contains no useful info and is not checked + self.check_pipes_connected(self.wds, rds2) + + # check that there is no more data waiting + with self.assertRaises(BlockingIOError): + msg, rds2, msg_flags, addr = socket.recv_fds( + sock2, len(MSG) * 2, len(self.rds) * 2 + ) + + def testSendAddress(self): + """test receiving a DGRAM over a bound socket, sending with the + 'address' argument of send_fds""" + + # create a bound socket and use a second unbound socket to test + # the address parameter of send_fds + with ( + tempfile.TemporaryDirectory() as tmpdir, + socket.socket(family=socket.AF_UNIX, type=socket.SOCK_DGRAM) as sock1, + socket.socket(family=socket.AF_UNIX, type=socket.SOCK_DGRAM) as sock2 + ): + # bind sock1 + sockpth = os.path.join(tmpdir, "SOCK") + sock1.bind(sockpth) + # send from sock2 + socket.send_fds(sock2, [MSG], self.rds, address=sockpth) + # receive on sock1 + msg, rds2, msg_flags, addr = socket.recv_fds( + sock1, len(MSG) * 2, len(self.rds) * 2 + ) + self.addCleanup(self.close_fds, rds2) + + self.assertEqual(msg, MSG) + self.assertEqual(len(rds2), len(self.rds)) + self.assertEqual(msg_flags, 0) + # addr contains no useful info and is not checked + self.check_pipes_connected(self.wds, rds2) - for index, rfd in enumerate(fds2): - data = os.read(rfd, 100) - self.assertEqual(data, str(index).encode()) def setUpModule(): diff --git a/Misc/NEWS.d/next/Library/2023-08-22-01-08-12.gh-issue-107898.eH3Y4r.rst b/Misc/NEWS.d/next/Library/2023-08-22-01-08-12.gh-issue-107898.eH3Y4r.rst new file mode 100644 index 00000000000000..c616f717bad38d --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-08-22-01-08-12.gh-issue-107898.eH3Y4r.rst @@ -0,0 +1,2 @@ +Make :func:`socket.send_fds` and :func:`socket.recv_fds` honour optional +arguments. Patch by Stefano Miccoli.