Skip to content

Commit d8f0e6c

Browse files
committed
Remove reader/writer when sock_recv(), sock_sendall(), etc are cancelled
1 parent b4f188e commit d8f0e6c

File tree

3 files changed

+104
-45
lines changed

3 files changed

+104
-45
lines changed

tests/test_sockets.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,65 @@ def test_socket_close_remove_writer(self):
505505
s.close()
506506
self.assertEqual(s.fileno(), -1)
507507

508+
def test_socket_cancel_sock_recv(self):
509+
def srv_gen(sock):
510+
time.sleep(1.2)
511+
sock.send(b'helo')
512+
513+
async def kill(fut):
514+
await asyncio.sleep(0.2, loop=self.loop)
515+
fut.cancel()
516+
517+
async def client(sock, addr):
518+
await self.loop.sock_connect(sock, addr)
519+
520+
f = asyncio.ensure_future(self.loop.sock_recv(sock, 10),
521+
loop=self.loop)
522+
self.loop.create_task(kill(f))
523+
with self.assertRaises(asyncio.CancelledError):
524+
await f
525+
sock.close()
526+
self.assertEqual(sock.fileno(), -1)
527+
528+
with self.tcp_server(srv_gen) as srv:
529+
530+
sock = socket.socket()
531+
with sock:
532+
sock.setblocking(False)
533+
c = client(sock, srv.addr)
534+
w = asyncio.wait_for(c, timeout=5.0, loop=self.loop)
535+
self.loop.run_until_complete(w)
536+
537+
def test_socket_cancel_sock_sendall(self):
538+
def srv_gen(sock):
539+
time.sleep(1.2)
540+
sock.recv_all(4)
541+
542+
async def kill(fut):
543+
await asyncio.sleep(0.2, loop=self.loop)
544+
fut.cancel()
545+
546+
async def client(sock, addr):
547+
await self.loop.sock_connect(sock, addr)
548+
549+
f = asyncio.ensure_future(
550+
self.loop.sock_sendall(sock, b'helo' * (1024 * 1024 * 50)),
551+
loop=self.loop)
552+
self.loop.create_task(kill(f))
553+
with self.assertRaises(asyncio.CancelledError):
554+
await f
555+
sock.close()
556+
self.assertEqual(sock.fileno(), -1)
557+
558+
with self.tcp_server(srv_gen) as srv:
559+
560+
sock = socket.socket()
561+
with sock:
562+
sock.setblocking(False)
563+
c = client(sock, srv.addr)
564+
w = asyncio.wait_for(c, timeout=5.0, loop=self.loop)
565+
self.loop.run_until_complete(w)
566+
508567

509568
class TestAIOSockets(_TestSockets, tb.AIOTestCase):
510569
pass

uvloop/loop.pxd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ cdef class Loop:
167167
cdef _fileobj_to_fd(self, fileobj)
168168
cdef _ensure_fd_no_transport(self, fd)
169169

170+
cdef _new_reader_future(self, sock)
171+
cdef _new_writer_future(self, sock)
170172
cdef _add_reader(self, fd, Handle handle)
171173
cdef _remove_reader(self, fd)
172174

@@ -178,7 +180,7 @@ cdef class Loop:
178180
cdef _sock_sendall(self, fut, sock, data)
179181
cdef _sock_accept(self, fut, sock)
180182

181-
cdef _sock_connect(self, fut, sock, address)
183+
cdef _sock_connect(self, sock, address)
182184
cdef _sock_connect_cb(self, fut, sock, address)
183185

184186
cdef _sock_set_reuseport(self, int fd)

uvloop/loop.pyx

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -794,14 +794,28 @@ cdef class Loop:
794794
nr.query(addr, flags)
795795
return fut
796796

797+
cdef _new_reader_future(self, sock):
798+
def _on_cancel(fut):
799+
if fut.cancelled():
800+
self._remove_reader(sock)
801+
802+
fut = self._new_future()
803+
fut.add_done_callback(_on_cancel)
804+
return fut
805+
806+
cdef _new_writer_future(self, sock):
807+
def _on_cancel(fut):
808+
if fut.cancelled():
809+
self._remove_writer(sock)
810+
811+
fut = self._new_future()
812+
fut.add_done_callback(_on_cancel)
813+
return fut
814+
797815
cdef _sock_recv(self, fut, sock, n):
798816
cdef:
799817
Handle handle
800818

801-
if fut.cancelled():
802-
self._remove_reader(sock)
803-
return
804-
805819
try:
806820
data = sock.recv(n)
807821
except (BlockingIOError, InterruptedError):
@@ -819,10 +833,6 @@ cdef class Loop:
819833
cdef:
820834
Handle handle
821835

822-
if fut.cancelled():
823-
self._remove_reader(sock)
824-
return
825-
826836
try:
827837
data = sock.recv_into(buf)
828838
except (BlockingIOError, InterruptedError):
@@ -841,10 +851,6 @@ cdef class Loop:
841851
Handle handle
842852
int n
843853

844-
if fut.cancelled():
845-
self._remove_writer(sock)
846-
return
847-
848854
try:
849855
n = sock.send(data)
850856
except (BlockingIOError, InterruptedError):
@@ -878,10 +884,6 @@ cdef class Loop:
878884
cdef:
879885
Handle handle
880886

881-
if fut.cancelled():
882-
self._remove_reader(sock)
883-
return
884-
885887
try:
886888
conn, address = sock.accept()
887889
conn.setblocking(False)
@@ -896,31 +898,29 @@ cdef class Loop:
896898
fut.set_result((conn, address))
897899
self._remove_reader(sock)
898900

899-
cdef _sock_connect(self, fut, sock, address):
901+
cdef _sock_connect(self, sock, address):
900902
cdef:
901903
Handle handle
902904

903905
try:
904906
sock.connect(address)
905907
except (BlockingIOError, InterruptedError):
906-
# Issue #23618: When the C function connect() fails with EINTR, the
907-
# connection runs in background. We have to wait until the socket
908-
# becomes writable to be notified when the connection succeed or
909-
# fails.
910-
fut.add_done_callback(lambda fut: self._remove_writer(sock))
908+
pass
909+
else:
910+
return
911911

912-
handle = new_MethodHandle3(
913-
self,
914-
"Loop._sock_connect",
915-
<method3_t>self._sock_connect_cb,
916-
self,
917-
fut, sock, address)
912+
fut = self._new_future()
913+
fut.add_done_callback(lambda fut: self._remove_writer(sock))
918914

919-
self._add_writer(sock, handle)
920-
except Exception as exc:
921-
fut.set_exception(exc)
922-
else:
923-
fut.set_result(None)
915+
handle = new_MethodHandle3(
916+
self,
917+
"Loop._sock_connect",
918+
<method3_t>self._sock_connect_cb,
919+
self,
920+
fut, sock, address)
921+
922+
self._add_writer(sock, handle)
923+
return fut
924924

925925
cdef _sock_connect_cb(self, fut, sock, address):
926926
if fut.cancelled():
@@ -2087,7 +2087,7 @@ cdef class Loop:
20872087
if self._debug and sock.gettimeout() != 0:
20882088
raise ValueError("the socket must be non-blocking")
20892089

2090-
fut = self._new_future()
2090+
fut = self._new_reader_future(sock)
20912091
handle = new_MethodHandle3(
20922092
self,
20932093
"Loop._sock_recv",
@@ -2112,7 +2112,7 @@ cdef class Loop:
21122112
if self._debug and sock.gettimeout() != 0:
21132113
raise ValueError("the socket must be non-blocking")
21142114

2115-
fut = self._new_future()
2115+
fut = self._new_reader_future(sock)
21162116
handle = new_MethodHandle3(
21172117
self,
21182118
"Loop._sock_recv_into",
@@ -2162,7 +2162,7 @@ cdef class Loop:
21622162
data = memoryview(data)
21632163
data = data[n:]
21642164

2165-
fut = self._new_future()
2165+
fut = self._new_writer_future(sock)
21662166
handle = new_MethodHandle3(
21672167
self,
21682168
"Loop._sock_sendall",
@@ -2191,7 +2191,7 @@ cdef class Loop:
21912191
if self._debug and sock.gettimeout() != 0:
21922192
raise ValueError("the socket must be non-blocking")
21932193

2194-
fut = self._new_future()
2194+
fut = self._new_reader_future(sock)
21952195
handle = new_MethodHandle2(
21962196
self,
21972197
"Loop._sock_accept",
@@ -2212,15 +2212,13 @@ cdef class Loop:
22122212

22132213
socket_inc_io_ref(sock)
22142214
try:
2215-
fut = self._new_future()
22162215
if sock.family == uv.AF_UNIX:
2217-
self._sock_connect(fut, sock, address)
2216+
fut = self._sock_connect(sock, address)
2217+
else:
2218+
_, _, _, _, address = (await self.getaddrinfo(*address[:2]))[0]
2219+
fut = self._sock_connect(sock, address)
2220+
if fut is not None:
22182221
await fut
2219-
return
2220-
2221-
_, _, _, _, address = (await self.getaddrinfo(*address[:2]))[0]
2222-
self._sock_connect(fut, sock, address)
2223-
await fut
22242222
finally:
22252223
socket_dec_io_ref(sock)
22262224

0 commit comments

Comments
 (0)