Skip to content

Commit b9ae8e2

Browse files
committed
Fix address and flags for send_fds/recv_fds
1 parent 40a4d88 commit b9ae8e2

File tree

2 files changed

+81
-8
lines changed

2 files changed

+81
-8
lines changed

Lib/socket.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,8 @@ def send_fds(sock, buffers, fds, flags=0, address=None):
563563
import array
564564

565565
return sock.sendmsg(buffers, [(_socket.SOL_SOCKET,
566-
_socket.SCM_RIGHTS, array.array("i", fds))])
566+
_socket.SCM_RIGHTS, array.array("i", fds))],
567+
flags, address)
567568
__all__.append("send_fds")
568569

569570
if hasattr(_socket.socket, "recvmsg"):
@@ -579,7 +580,7 @@ def recv_fds(sock, bufsize, maxfds, flags=0):
579580
# Array of ints
580581
fds = array.array("i")
581582
msg, ancdata, flags, addr = sock.recvmsg(bufsize,
582-
_socket.CMSG_LEN(maxfds * fds.itemsize))
583+
_socket.CMSG_LEN(maxfds * fds.itemsize), flags)
583584
for cmsg_level, cmsg_type, cmsg_data in ancdata:
584585
if (cmsg_level == _socket.SOL_SOCKET and cmsg_type == _socket.SCM_RIGHTS):
585586
fds.frombytes(cmsg_data[:

Lib/test/test_socket.py

Lines changed: 78 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7037,6 +7037,12 @@ def test_dual_stack_client_v6(self):
70377037
@requireAttrs(socket, "recv_fds")
70387038
@requireAttrs(socket, "AF_UNIX")
70397039
class SendRecvFdsTests(unittest.TestCase):
7040+
def _test_pipe(self, rfd, wfd, msg):
7041+
assert len(msg) < 512
7042+
os.write(wfd, msg)
7043+
data = os.read(rfd, 512)
7044+
self.assertEqual(data, msg)
7045+
70407046
def testSendAndRecvFds(self):
70417047
def close_pipes(pipes):
70427048
for fd1, fd2 in pipes:
@@ -7066,13 +7072,79 @@ def close_fds(fds):
70667072
# don't test addr
70677073

70687074
# test that file descriptors are connected
7069-
for index, fds in enumerate(pipes):
7070-
rfd, wfd = fds
7071-
os.write(wfd, str(index).encode())
7075+
for index, ((_, wfd), rfd) in enumerate(zip(pipes, fds2)):
7076+
self._test_pipe(rfd, wfd, str(index).encode())
7077+
7078+
def test_send_recv_fds_with_addrs(self):
7079+
sock1 = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
7080+
sock2 = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
7081+
rfd, wfd = os.pipe()
7082+
self.addCleanup(os.close, rfd)
7083+
self.addCleanup(os.close, wfd)
7084+
7085+
with tempfile.TemporaryDirectory() as tmpdir, sock1, sock2:
7086+
sock1_addr = os.path.join(tmpdir, "sock1")
7087+
sock2_addr = os.path.join(tmpdir, "sock2")
7088+
sock1.bind(sock1_addr)
7089+
sock2.bind(sock2_addr)
7090+
sock2.setblocking(False)
7091+
7092+
socket.send_fds(sock1, [MSG], [rfd], address=sock2_addr)
7093+
msg, fds, flags, addr = socket.recv_fds(sock2, len(MSG), 1)
7094+
new_rfd = fds[0]
7095+
self.addCleanup(os.close, new_rfd)
7096+
7097+
self.assertEqual(msg, MSG)
7098+
self.assertEqual(len(fds), 1)
7099+
self.assertEqual(addr, sock1_addr)
7100+
7101+
self._test_pipe(new_rfd, wfd, MSG)
7102+
7103+
@requireAttrs(socket, "MSG_PEEK")
7104+
def test_recv_fds_peek(self):
7105+
rfd, wfd = os.pipe()
7106+
self.addCleanup(os.close, rfd)
7107+
self.addCleanup(os.close, wfd)
70727108

7073-
for index, rfd in enumerate(fds2):
7074-
data = os.read(rfd, 100)
7075-
self.assertEqual(data, str(index).encode())
7109+
sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
7110+
with sock1, sock2:
7111+
socket.send_fds(sock1, [MSG], [rfd])
7112+
sock2.setblocking(False)
7113+
7114+
# peek message on sock2
7115+
peek_len = len(MSG) // 2
7116+
msg, fds, flags, addr = socket.recv_fds(sock2, peek_len, 1,
7117+
socket.MSG_PEEK)
7118+
self.addCleanup(os.close, fds[0])
7119+
self.assertEqual(len(msg), peek_len)
7120+
self.assertEqual(msg, MSG[:peek_len])
7121+
self.assertEqual(flags & socket.MSG_TRUNC, socket.MSG_TRUNC)
7122+
self._test_pipe(fds[0], wfd, MSG)
7123+
7124+
# will raise BlockingIOError if MSG_PEEK didn't work
7125+
msg, fds, flags, addr = socket.recv_fds(sock2, len(MSG), 1)
7126+
self.addCleanup(os.close, fds[0])
7127+
self.assertEqual(msg, MSG)
7128+
self._test_pipe(fds[0], wfd, MSG)
7129+
7130+
@requireAttrs(socket, "MSG_DONTWAIT")
7131+
def test_send_fds_dontwait(self):
7132+
rfd, wfd = os.pipe()
7133+
self.addCleanup(os.close, rfd)
7134+
self.addCleanup(os.close, wfd)
7135+
7136+
sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
7137+
with sock1, sock2:
7138+
sock1.setblocking(True)
7139+
with self.assertRaises(BlockingIOError):
7140+
for _ in range(64 * 1024):
7141+
socket.send_fds(sock1, [MSG], [rfd], socket.MSG_DONTWAIT)
7142+
7143+
msg, fds, flags, addr = socket.recv_fds(sock2, len(MSG), 1)
7144+
self.addCleanup(os.close, fds[0])
7145+
7146+
self.assertEqual(msg, MSG)
7147+
self._test_pipe(fds[0], wfd, MSG)
70767148

70777149

70787150
class FreeThreadingTests(unittest.TestCase):

0 commit comments

Comments
 (0)