Skip to content

Commit 92fb0fc

Browse files
committed
Make loop methods reject socket kinds they do not support.
More specifically: * loop.create_connection() and loop.create_server() can accept AF_INET or AF_INET6 SOCK_STREAM sockets; * loop.create_datagram_endpoint() can accept only SOCK_DGRAM sockets; * loop.connect_accepted_socket() can accept only SOCK_STREAM sockets; * fixed a bug in create_unix_server() and create_unix_connection() to properly check for SOCK_STREAM sockets on Linux; * fixed static DNS resolution to decline socket types that aren't strictly equal to SOCK_STREAM or SOCK_DGRAM. On Linux socket type can be a bit mask, and we should let system getaddrinfo() to deal with it.
1 parent 84220e6 commit 92fb0fc

File tree

7 files changed

+110
-18
lines changed

7 files changed

+110
-18
lines changed

tests/test_dns.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ def test_getaddrinfo_19(self):
130130
self._test_getaddrinfo('::1', 80)
131131
self._test_getaddrinfo('::1', 80, type=socket.SOCK_STREAM)
132132

133+
def test_getaddrinfo_20(self):
134+
self._test_getaddrinfo('127.0.0.1', 80)
135+
self._test_getaddrinfo('127.0.0.1', 80, type=socket.SOCK_STREAM)
136+
133137
######
134138

135139
def test_getnameinfo_1(self):

tests/test_tcp.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,35 @@ def client():
806806
tr, _ = f.result()
807807
tr.close()
808808

809+
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
810+
def test_create_connection_wrong_sock(self):
811+
sock = socket.socket(socket.AF_UNIX)
812+
with sock:
813+
coro = self.loop.create_connection(MyBaseProto, sock=sock)
814+
with self.assertRaisesRegex(ValueError,
815+
'A TCP Stream Socket was expected'):
816+
self.loop.run_until_complete(coro)
817+
818+
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
819+
def test_create_server_wrong_sock(self):
820+
sock = socket.socket(socket.AF_UNIX)
821+
with sock:
822+
coro = self.loop.create_server(MyBaseProto, sock=sock)
823+
with self.assertRaisesRegex(ValueError,
824+
'A TCP Stream Socket was expected'):
825+
self.loop.run_until_complete(coro)
826+
827+
@unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
828+
'no socket.SOCK_NONBLOCK (linux only)')
829+
def test_create_server_stream_bittype(self):
830+
sock = socket.socket(
831+
socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
832+
with sock:
833+
coro = self.loop.create_server(lambda: None, sock=sock)
834+
srv = self.loop.run_until_complete(coro)
835+
srv.close()
836+
self.loop.run_until_complete(srv.wait_closed())
837+
809838

810839
class Test_AIO_TCP(_TestTCP, tb.AIOTestCase):
811840
pass

tests/test_udp.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,14 @@ def test_create_datagram_endpoint_sock(self):
131131

132132

133133
class Test_UV_UDP(_TestUDP, tb.UVTestCase):
134-
pass
134+
135+
def test_create_datagram_endpoint_wrong_sock(self):
136+
sock = socket.socket(socket.AF_INET)
137+
with sock:
138+
coro = self.loop.create_datagram_endpoint(lambda: None, sock=sock)
139+
with self.assertRaisesRegex(ValueError,
140+
'A UDP Socket was expected'):
141+
self.loop.run_until_complete(coro)
135142

136143

137144
class Test_AIO_UDP(_TestUDP, tb.AIOTestCase):

tests/test_unix.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import socket
44
import tempfile
5+
import unittest
56

67
from uvloop import _testbase as tb
78

@@ -354,6 +355,33 @@ async def test(sock):
354355
with s1, s2:
355356
self.loop.run_until_complete(test(s1))
356357

358+
def test_create_unix_server_path_dgram(self):
359+
sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
360+
with sock:
361+
coro = self.loop.create_unix_server(lambda: None, path=None,
362+
sock=sock)
363+
with self.assertRaisesRegex(ValueError,
364+
'A UNIX Domain Stream.*was expected'):
365+
self.loop.run_until_complete(coro)
366+
367+
@unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
368+
'no socket.SOCK_NONBLOCK (linux only)')
369+
def test_create_unix_server_path_stream_bittype(self):
370+
sock = socket.socket(
371+
socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
372+
with tempfile.NamedTemporaryFile() as file:
373+
fn = file.name
374+
try:
375+
with sock:
376+
sock.bind(fn)
377+
coro = self.loop.create_unix_server(lambda: None, path=None,
378+
sock=sock)
379+
srv = self.loop.run_until_complete(coro)
380+
srv.close()
381+
self.loop.run_until_complete(srv.wait_closed())
382+
finally:
383+
os.unlink(fn)
384+
357385

358386
class Test_AIO_Unix(_TestUnix, tb.AIOTestCase):
359387
pass

uvloop/dns.pyx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,12 @@ cdef __static_getaddrinfo(object host, object port,
126126
if proto not in {0, uv.IPPROTO_TCP, uv.IPPROTO_UDP}:
127127
raise LookupError
128128

129-
type &= ~_SOCKET_TYPE_MASK
130129
if type == uv.SOCK_STREAM:
130+
# Linux only:
131+
# getaddrinfo() can raise when socket.type is a bit mask.
132+
# So if socket.type is a bit mask of SOCK_STREAM, and say
133+
# SOCK_NONBLOCK, we simply return None, which will trigger
134+
# a call to getaddrinfo() letting it process this request.
131135
proto = uv.IPPROTO_TCP
132136
elif type == uv.SOCK_DGRAM:
133137
proto = uv.IPPROTO_UDP

uvloop/includes/stdlib.pxi

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,6 @@ cdef int socket_EAI_SERVICE = getattr(socket, 'EAI_SERVICE', -1)
8383
cdef int socket_EAI_SOCKTYPE = getattr(socket, 'EAI_SOCKTYPE', -1)
8484

8585

86-
cdef int _SOCKET_TYPE_MASK = 0
87-
if hasattr(socket, 'SOCK_NONBLOCK'):
88-
_SOCKET_TYPE_MASK |= socket.SOCK_NONBLOCK
89-
if hasattr(socket, 'SOCK_CLOEXEC'):
90-
_SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC
91-
92-
9386
cdef str os_name = os.name
9487
cdef os_environ = os.environ
9588
cdef os_dup = os.dup

uvloop/loop.pyx

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,24 @@ include "includes/stdlib.pxi"
4040
include "errors.pyx"
4141

4242

43+
cdef _is_sock_ip(sock_family):
44+
return sock_family == uv.AF_INET or sock_family == uv.AF_INET6
45+
46+
47+
cdef _is_sock_stream(sock_type):
48+
# Linux's socket.type is a bitmask that can include extra info
49+
# about socket, therefore we can't do simple
50+
# `sock_type == socket.SOCK_STREAM`.
51+
return (sock_type & uv.SOCK_STREAM) == uv.SOCK_STREAM
52+
53+
54+
cdef _is_sock_dgram(sock_type):
55+
# Linux's socket.type is a bitmask that can include extra info
56+
# about socket, therefore we can't do simple
57+
# `sock_type == socket.SOCK_DGRAM`.
58+
return (sock_type & uv.SOCK_DGRAM) == uv.SOCK_DGRAM
59+
60+
4361
cdef isfuture(obj):
4462
if aio_isfuture is None:
4563
return isinstance(obj, aio_Future)
@@ -1322,6 +1340,10 @@ cdef class Loop:
13221340
else:
13231341
if sock is None:
13241342
raise ValueError('Neither host/port nor sock were specified')
1343+
if (not _is_sock_stream(sock.type) or
1344+
not _is_sock_ip(sock.family)):
1345+
raise ValueError(
1346+
'A TCP Stream Socket was expected, got {!r}'.format(sock))
13251347
tcp = TCPServer.new(self, protocol_factory, server, ssl,
13261348
uv.AF_UNSPEC)
13271349

@@ -1505,6 +1527,10 @@ cdef class Loop:
15051527
if sock is None:
15061528
raise ValueError(
15071529
'host and port was not specified and no sock specified')
1530+
if (not _is_sock_stream(sock.type) or
1531+
not _is_sock_ip(sock.family)):
1532+
raise ValueError(
1533+
'A TCP Stream Socket was expected, got {!r}'.format(sock))
15081534

15091535
waiter = self._new_future()
15101536
tr = TCPTransport.new(self, protocol, None, waiter)
@@ -1578,8 +1604,6 @@ cdef class Loop:
15781604
if ssl is not None and not isinstance(ssl, ssl_SSLContext):
15791605
raise TypeError('ssl argument must be an SSLContext or None')
15801606

1581-
pipe = UnixServer.new(self, protocol_factory, server, ssl)
1582-
15831607
if path is not None:
15841608
if sock is not None:
15851609
raise ValueError(
@@ -1594,7 +1618,6 @@ cdef class Loop:
15941618
try:
15951619
sock.bind(path)
15961620
except OSError as exc:
1597-
pipe._close()
15981621
sock.close()
15991622
if exc.errno == errno.EADDRINUSE:
16001623
# Let's improve the error message by adding
@@ -1604,7 +1627,6 @@ cdef class Loop:
16041627
else:
16051628
raise
16061629
except:
1607-
pipe._close()
16081630
sock.close()
16091631
raise
16101632

@@ -1613,11 +1635,13 @@ cdef class Loop:
16131635
raise ValueError(
16141636
'path was not specified, and no sock specified')
16151637

1616-
if sock.family != uv.AF_UNIX or sock.type != uv.SOCK_STREAM:
1638+
if sock.family != uv.AF_UNIX or not _is_sock_stream(sock.type):
16171639
raise ValueError(
16181640
'A UNIX Domain Stream Socket was expected, got {!r}'
16191641
.format(sock))
16201642

1643+
pipe = UnixServer.new(self, protocol_factory, server, ssl)
1644+
16211645
try:
16221646
# See a comment on os_dup in create_connection
16231647
fileno = os_dup(sock.fileno())
@@ -1686,7 +1710,7 @@ cdef class Loop:
16861710
if sock is None:
16871711
raise ValueError('no path and sock were specified')
16881712

1689-
if sock.family != uv.AF_UNIX or sock.type != uv.SOCK_STREAM:
1713+
if sock.family != uv.AF_UNIX or not _is_sock_stream(sock.type):
16901714
raise ValueError(
16911715
'A UNIX Domain Stream Socket was expected, got {!r}'
16921716
.format(sock))
@@ -1989,9 +2013,9 @@ cdef class Loop:
19892013

19902014
if ssl is not None and not isinstance(ssl, ssl_SSLContext):
19912015
raise TypeError('ssl argument must be an SSLContext or None')
1992-
1993-
if sock.type != uv.SOCK_STREAM:
1994-
raise ValueError('invalid socket type, SOCK_STREAM expected')
2016+
if not _is_sock_stream(sock.type):
2017+
raise ValueError(
2018+
'A Stream Socket was expected, got {!r}'.format(sock))
19952019

19962020
# See a comment on os_dup in create_connection
19972021
fileno = os_dup(sock.fileno())
@@ -2296,6 +2320,9 @@ cdef class Loop:
22962320
system.addrinfo * rai
22972321

22982322
if sock is not None:
2323+
if not _is_sock_dgram(sock.type):
2324+
raise ValueError(
2325+
'A UDP Socket was expected, got {!r}'.format(sock))
22992326
if (local_addr or remote_addr or
23002327
family or proto or flags or
23012328
reuse_address or reuse_port or allow_broadcast):

0 commit comments

Comments
 (0)