Skip to content

Commit f691212

Browse files
committed
Fix context in protocol callbacks (#348)
This is a combined fix to correct contexts from which protocal callbacks are invoked. In short, callbacks like data_received() should always be invoked from consistent contexts which are copied from the context where the underlying UVHandle is created or started. The new test case covers also asyncio, but skipping the failing ones.
1 parent 7b202cc commit f691212

19 files changed

+791
-132
lines changed

tests/test_context.py

Lines changed: 601 additions & 22 deletions
Large diffs are not rendered by default.

tests/test_sockets.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,10 @@ def test_socket_sync_remove_and_immediately_close(self):
190190
self.loop.run_until_complete(asyncio.sleep(0.01))
191191

192192
def test_sock_cancel_add_reader_race(self):
193-
if self.is_asyncio_loop():
194-
if sys.version_info[:2] == (3, 8):
195-
# asyncio 3.8.x has a regression; fixed in 3.9.0
196-
# tracked in https://bugs.python.org/issue30064
197-
raise unittest.SkipTest()
193+
if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8):
194+
# asyncio 3.8.x has a regression; fixed in 3.9.0
195+
# tracked in https://bugs.python.org/issue30064
196+
raise unittest.SkipTest()
198197

199198
srv_sock_conn = None
200199

@@ -247,8 +246,8 @@ async def send_server_data():
247246
self.loop.run_until_complete(server())
248247

249248
def test_sock_send_before_cancel(self):
250-
if self.is_asyncio_loop() and sys.version_info[:3] == (3, 8, 0):
251-
# asyncio 3.8.0 seems to have a regression;
249+
if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8):
250+
# asyncio 3.8.x has a regression; fixed in 3.9.0
252251
# tracked in https://bugs.python.org/issue30064
253252
raise unittest.SkipTest()
254253

uvloop/cbhandles.pyx

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -333,71 +333,72 @@ cdef new_Handle(Loop loop, object callback, object args, object context):
333333
return handle
334334

335335

336-
cdef new_MethodHandle(Loop loop, str name, method_t callback, object ctx):
336+
cdef new_MethodHandle(Loop loop, str name, method_t callback, object context,
337+
object bound_to):
337338
cdef Handle handle
338339
handle = Handle.__new__(Handle)
339340
handle._set_loop(loop)
340-
handle._set_context(None)
341+
handle._set_context(context)
341342

342343
handle.cb_type = 2
343344
handle.meth_name = name
344345

345346
handle.callback = <void*> callback
346-
handle.arg1 = ctx
347+
handle.arg1 = bound_to
347348

348349
return handle
349350

350351

351-
cdef new_MethodHandle1(Loop loop, str name, method1_t callback,
352-
object ctx, object arg):
352+
cdef new_MethodHandle1(Loop loop, str name, method1_t callback, object context,
353+
object bound_to, object arg):
353354

354355
cdef Handle handle
355356
handle = Handle.__new__(Handle)
356357
handle._set_loop(loop)
357-
handle._set_context(None)
358+
handle._set_context(context)
358359

359360
handle.cb_type = 3
360361
handle.meth_name = name
361362

362363
handle.callback = <void*> callback
363-
handle.arg1 = ctx
364+
handle.arg1 = bound_to
364365
handle.arg2 = arg
365366

366367
return handle
367368

368369

369-
cdef new_MethodHandle2(Loop loop, str name, method2_t callback, object ctx,
370-
object arg1, object arg2):
370+
cdef new_MethodHandle2(Loop loop, str name, method2_t callback, object context,
371+
object bound_to, object arg1, object arg2):
371372

372373
cdef Handle handle
373374
handle = Handle.__new__(Handle)
374375
handle._set_loop(loop)
375-
handle._set_context(None)
376+
handle._set_context(context)
376377

377378
handle.cb_type = 4
378379
handle.meth_name = name
379380

380381
handle.callback = <void*> callback
381-
handle.arg1 = ctx
382+
handle.arg1 = bound_to
382383
handle.arg2 = arg1
383384
handle.arg3 = arg2
384385

385386
return handle
386387

387388

388-
cdef new_MethodHandle3(Loop loop, str name, method3_t callback, object ctx,
389-
object arg1, object arg2, object arg3):
389+
cdef new_MethodHandle3(Loop loop, str name, method3_t callback, object context,
390+
object bound_to, object arg1, object arg2, object arg3):
390391

391392
cdef Handle handle
392393
handle = Handle.__new__(Handle)
393394
handle._set_loop(loop)
394-
handle._set_context(None)
395+
handle._set_context(context)
395396

396397
handle.cb_type = 5
397398
handle.meth_name = name
398399

399400
handle.callback = <void*> callback
400-
handle.arg1 = ctx
401+
handle.arg1 = bound_to
401402
handle.arg2 = arg1
402403
handle.arg3 = arg2
403404
handle.arg4 = arg3

uvloop/handles/basetransport.pyx

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@ cdef class UVBaseTransport(UVSocketHandle):
2626
new_MethodHandle(self._loop,
2727
"UVTransport._call_connection_made",
2828
<method_t>self._call_connection_made,
29+
self.context,
2930
self))
3031

3132
cdef inline _schedule_call_connection_lost(self, exc):
3233
self._loop._call_soon_handle(
3334
new_MethodHandle1(self._loop,
3435
"UVTransport._call_connection_lost",
3536
<method1_t>self._call_connection_lost,
37+
self.context,
3638
self, exc))
3739

3840
cdef _fatal_error(self, exc, throw, reason=None):
@@ -66,7 +68,9 @@ cdef class UVBaseTransport(UVSocketHandle):
6668
if not self._protocol_paused:
6769
self._protocol_paused = 1
6870
try:
69-
self._protocol.pause_writing()
71+
# _maybe_pause_protocol() is always triggered from user-calls,
72+
# so we must copy the context to avoid entering context twice
73+
self.context.copy().run(self._protocol.pause_writing)
7074
except (KeyboardInterrupt, SystemExit):
7175
raise
7276
except BaseException as exc:
@@ -84,7 +88,10 @@ cdef class UVBaseTransport(UVSocketHandle):
8488
if self._protocol_paused and size <= self._low_water:
8589
self._protocol_paused = 0
8690
try:
87-
self._protocol.resume_writing()
91+
# We're copying the context to avoid entering context twice,
92+
# even though it's not always necessary to copy - it's easier
93+
# to copy here than passing down a copied context.
94+
self.context.copy().run(self._protocol.resume_writing)
8895
except (KeyboardInterrupt, SystemExit):
8996
raise
9097
except BaseException as exc:

uvloop/handles/handle.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ cdef class UVHandle:
55
readonly _source_traceback
66
bint _closed
77
bint _inited
8+
object context
89

910
# Added to enable current UDPTransport implementation,
1011
# which doesn't use libuv handles.

uvloop/handles/pipe.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ cdef class UnixTransport(UVStream):
1414

1515
@staticmethod
1616
cdef UnixTransport new(Loop loop, object protocol, Server server,
17-
object waiter)
17+
object waiter, object context)
1818

1919
cdef connect(self, char* addr)
2020

uvloop/handles/pipe.pyx

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,11 @@ cdef class UnixServer(UVStreamServer):
7373

7474
self._mark_as_open()
7575

76-
cdef UVStream _make_new_transport(self, object protocol, object waiter):
76+
cdef UVStream _make_new_transport(self, object protocol, object waiter,
77+
object context):
7778
cdef UnixTransport tr
78-
tr = UnixTransport.new(self._loop, protocol, self._server, waiter)
79+
tr = UnixTransport.new(self._loop, protocol, self._server, waiter,
80+
context)
7981
return <UVStream>tr
8082

8183

@@ -84,11 +86,11 @@ cdef class UnixTransport(UVStream):
8486

8587
@staticmethod
8688
cdef UnixTransport new(Loop loop, object protocol, Server server,
87-
object waiter):
89+
object waiter, object context):
8890

8991
cdef UnixTransport handle
9092
handle = UnixTransport.__new__(UnixTransport)
91-
handle._init(loop, protocol, server, waiter)
93+
handle._init(loop, protocol, server, waiter, context)
9294
__pipe_init_uv_handle(<UVStream>handle, loop)
9395
return handle
9496

@@ -112,7 +114,9 @@ cdef class ReadUnixTransport(UVStream):
112114
object waiter):
113115
cdef ReadUnixTransport handle
114116
handle = ReadUnixTransport.__new__(ReadUnixTransport)
115-
handle._init(loop, protocol, server, waiter)
117+
# This is only used in connect_read_pipe() and subprocess_shell/exec()
118+
# directly, we could simply copy the current context.
119+
handle._init(loop, protocol, server, waiter, Context_CopyCurrent())
116120
__pipe_init_uv_handle(<UVStream>handle, loop)
117121
return handle
118122

@@ -162,7 +166,9 @@ cdef class WriteUnixTransport(UVStream):
162166
# close the transport.
163167
handle._close_on_read_error()
164168

165-
handle._init(loop, protocol, server, waiter)
169+
# This is only used in connect_write_pipe() and subprocess_shell/exec()
170+
# directly, we could simply copy the current context.
171+
handle._init(loop, protocol, server, waiter, Context_CopyCurrent())
166172
__pipe_init_uv_handle(<UVStream>handle, loop)
167173
return handle
168174

uvloop/handles/process.pyx

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ cdef class UVProcess(UVHandle):
1010
self._fds_to_close = set()
1111
self._preexec_fn = None
1212
self._restore_signals = True
13+
self.context = Context_CopyCurrent()
1314

1415
cdef _close_process_handle(self):
1516
# XXX: This is a workaround for a libuv bug:
@@ -364,7 +365,8 @@ cdef class UVProcessTransport(UVProcess):
364365
UVProcess._on_exit(self, exit_status, term_signal)
365366

366367
if self._stdio_ready:
367-
self._loop.call_soon(self._protocol.process_exited)
368+
self._loop.call_soon(self._protocol.process_exited,
369+
context=self.context)
368370
else:
369371
self._pending_calls.append((_CALL_PROCESS_EXITED, None, None))
370372

@@ -383,14 +385,16 @@ cdef class UVProcessTransport(UVProcess):
383385

384386
cdef _pipe_connection_lost(self, int fd, exc):
385387
if self._stdio_ready:
386-
self._loop.call_soon(self._protocol.pipe_connection_lost, fd, exc)
388+
self._loop.call_soon(self._protocol.pipe_connection_lost, fd, exc,
389+
context=self.context)
387390
self._try_finish()
388391
else:
389392
self._pending_calls.append((_CALL_PIPE_CONNECTION_LOST, fd, exc))
390393

391394
cdef _pipe_data_received(self, int fd, data):
392395
if self._stdio_ready:
393-
self._loop.call_soon(self._protocol.pipe_data_received, fd, data)
396+
self._loop.call_soon(self._protocol.pipe_data_received, fd, data,
397+
context=self.context)
394398
else:
395399
self._pending_calls.append((_CALL_PIPE_DATA_RECEIVED, fd, data))
396400

@@ -517,6 +521,7 @@ cdef class UVProcessTransport(UVProcess):
517521

518522
cdef _call_connection_made(self, waiter):
519523
try:
524+
# we're always called in the right context, so just call the user's
520525
self._protocol.connection_made(self)
521526
except (KeyboardInterrupt, SystemExit):
522527
raise
@@ -556,7 +561,9 @@ cdef class UVProcessTransport(UVProcess):
556561
self._finished = 1
557562

558563
if self._stdio_ready:
559-
self._loop.call_soon(self._protocol.connection_lost, None)
564+
# copy self.context for simplicity
565+
self._loop.call_soon(self._protocol.connection_lost, None,
566+
context=self.context)
560567
else:
561568
self._pending_calls.append((_CALL_CONNECTION_LOST, None, None))
562569

@@ -572,6 +579,7 @@ cdef class UVProcessTransport(UVProcess):
572579
new_MethodHandle1(self._loop,
573580
"UVProcessTransport._call_connection_made",
574581
<method1_t>self._call_connection_made,
582+
None, # means to copy the current context
575583
self, waiter))
576584

577585
@staticmethod
@@ -598,6 +606,8 @@ cdef class UVProcessTransport(UVProcess):
598606
if handle._init_futs:
599607
handle._stdio_ready = 0
600608
init_fut = aio_gather(*handle._init_futs)
609+
# add_done_callback will copy the current context and run the
610+
# callback within the context
601611
init_fut.add_done_callback(
602612
ft_partial(handle.__stdio_inited, waiter))
603613
else:
@@ -606,6 +616,7 @@ cdef class UVProcessTransport(UVProcess):
606616
new_MethodHandle1(loop,
607617
"UVProcessTransport._call_connection_made",
608618
<method1_t>handle._call_connection_made,
619+
None, # means to copy the current context
609620
handle, waiter))
610621

611622
return handle

uvloop/handles/stream.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ cdef class UVStream(UVBaseTransport):
1919
# All "inline" methods are final
2020

2121
cdef inline _init(self, Loop loop, object protocol, Server server,
22-
object waiter)
22+
object waiter, object context)
2323

2424
cdef inline _exec_write(self)
2525

uvloop/handles/stream.pyx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ cdef class UVStream(UVBaseTransport):
612612
except AttributeError:
613613
keep_open = False
614614
else:
615-
keep_open = meth()
615+
keep_open = self.context.run(meth)
616616

617617
if keep_open:
618618
# We're keeping the connection open so the
@@ -631,8 +631,8 @@ cdef class UVStream(UVBaseTransport):
631631
self._shutdown()
632632

633633
cdef inline _init(self, Loop loop, object protocol, Server server,
634-
object waiter):
635-
634+
object waiter, object context):
635+
self.context = context
636636
self._set_protocol(protocol)
637637
self._start_init(loop)
638638

@@ -826,7 +826,7 @@ cdef inline void __uv_stream_on_read_impl(uv.uv_stream_t* stream,
826826
if UVLOOP_DEBUG:
827827
loop._debug_stream_read_cb_total += 1
828828

829-
sc._protocol_data_received(loop._recv_buffer[:nread])
829+
sc.context.run(sc._protocol_data_received, loop._recv_buffer[:nread])
830830
except BaseException as exc:
831831
if UVLOOP_DEBUG:
832832
loop._debug_stream_read_cb_errors_total += 1
@@ -911,7 +911,7 @@ cdef void __uv_stream_buffered_alloc(uv.uv_handle_t* stream,
911911

912912
sc._read_pybuf_acquired = 0
913913
try:
914-
buf = sc._protocol_get_buffer(suggested_size)
914+
buf = sc.context.run(sc._protocol_get_buffer, suggested_size)
915915
PyObject_GetBuffer(buf, pybuf, PyBUF_WRITABLE)
916916
got_buf = 1
917917
except BaseException as exc:
@@ -976,7 +976,7 @@ cdef void __uv_stream_buffered_on_read(uv.uv_stream_t* stream,
976976
if UVLOOP_DEBUG:
977977
loop._debug_stream_read_cb_total += 1
978978

979-
sc._protocol_buffer_updated(nread)
979+
sc.context.run(sc._protocol_buffer_updated, nread)
980980
except BaseException as exc:
981981
if UVLOOP_DEBUG:
982982
loop._debug_stream_read_cb_errors_total += 1

0 commit comments

Comments
 (0)