Skip to content

Commit a8cfab6

Browse files
Aniket Pansefacebook-github-bot
authored andcommitted
Add ability to set awaiters on coroutines and futures
Summary: Fixed a refleak in the original change. This is essentially a squashed version of two diffs: D67813375 and D68596487. Landing it together for safety. Reviewed By: aleivag Differential Revision: D68597387 fbshipit-source-id: 2bfe721a5cf74cd71a9782718da557d1fee9444a
1 parent ae66b18 commit a8cfab6

File tree

10 files changed

+927
-23
lines changed

10 files changed

+927
-23
lines changed

Include/cpython/genobject.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@
77
extern "C" {
88
#endif
99

10+
static inline void Ci_PyAwaitable_SetAwaiter(PyObject *receiver, PyObject *awaiter) {
11+
PyTypeObject *ty = Py_TYPE(receiver);
12+
if (!PyType_HasFeature(ty, Ci_TPFLAGS_HAVE_AM_EXTRA)) {
13+
return;
14+
}
15+
Ci_AsyncMethodsWithExtra *ame = (Ci_AsyncMethodsWithExtra *)ty->tp_as_async;
16+
if ((ame != NULL) && (ame->ame_setawaiter != NULL)) {
17+
ame->ame_setawaiter(receiver, awaiter);
18+
}
19+
}
20+
1021
/* --- Generators --------------------------------------------------------- */
1122

1223
/* _PyGenObject_HEAD defines the initial segment of generator

Lib/asyncio/tasks.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
'current_task', 'all_tasks',
99
'create_eager_task_factory', 'eager_task_factory',
1010
'_register_task', '_unregister_task', '_enter_task', '_leave_task',
11+
'get_async_stack',
1112
)
1213

1314
import concurrent.futures
@@ -16,6 +17,7 @@
1617
import inspect
1718
import itertools
1819
import types
20+
import sys
1921
import warnings
2022
import weakref
2123
from types import GenericAlias
@@ -732,6 +734,11 @@ def cancel(self, msg=None):
732734
self._cancel_requested = True
733735
return ret
734736

737+
def __set_awaiter__(self, awaiter):
738+
for child in self._children:
739+
if hasattr(child, "__set_awaiter__"):
740+
child.__set_awaiter__(awaiter)
741+
735742

736743
def gather(*coros_or_futures, return_exceptions=False):
737744
"""Return a future aggregating results from the given coroutines/futures.
@@ -956,6 +963,62 @@ def callback():
956963
return future
957964

958965

966+
def get_async_stack():
967+
"""Return the async call stack for the currently executing task as a list of
968+
frames, with the most recent frame last.
969+
The async call stack consists of the call stack for the currently executing
970+
task, if any, plus the call stack formed by the transitive set of coroutines/async
971+
generators awaiting the current task.
972+
Consider the following example, where T represents a task, C represents
973+
a coroutine, and A '->' B indicates A is awaiting B.
974+
T0 +---> T1
975+
| | |
976+
C0 | C2
977+
| | |
978+
v | v
979+
C1 | C3
980+
| |
981+
+-----|
982+
The await stack from C3 would be C3, C2, C1, C0. In contrast, the
983+
synchronous call stack while C3 is executing is only C3, C2.
984+
"""
985+
if not hasattr(sys, "_getframe"):
986+
return []
987+
988+
task = current_task()
989+
coro = task.get_coro()
990+
coro_frame = coro.cr_frame
991+
992+
# Get the active portion of the stack
993+
stack = []
994+
frame = sys._getframe().f_back
995+
while frame is not None:
996+
stack.append(frame)
997+
if frame is coro_frame:
998+
break
999+
frame = frame.f_back
1000+
assert frame is coro_frame
1001+
1002+
# Get the suspended portion of the stack
1003+
awaiter = coro.cr_awaiter
1004+
while awaiter is not None:
1005+
if hasattr(awaiter, "cr_frame"):
1006+
stack.append(awaiter.cr_frame)
1007+
awaiter = awaiter.cr_awaiter
1008+
elif hasattr(awaiter, "ag_frame"):
1009+
stack.append(awaiter.ag_frame)
1010+
awaiter = awaiter.ag_awaiter
1011+
else:
1012+
raise ValueError(f"Unexpected awaiter {awaiter}")
1013+
1014+
stack.reverse()
1015+
return stack
1016+
1017+
1018+
# WeakSet containing all alive tasks.
1019+
_all_tasks = weakref.WeakSet()
1020+
1021+
9591022
def create_eager_task_factory(custom_task_constructor):
9601023
"""Create a function suitable for use as a task factory on an event-loop.
9611024

Lib/test/test_asyncgen.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,5 +1903,30 @@ async def run():
19031903
self.loop.run_until_complete(run())
19041904

19051905

1906+
class AsyncGeneratorAwaiterTest(unittest.TestCase):
1907+
def setUp(self):
1908+
self.loop = asyncio.new_event_loop()
1909+
asyncio.set_event_loop(None)
1910+
1911+
def tearDown(self):
1912+
self.loop.close()
1913+
self.loop = None
1914+
asyncio.set_event_loop_policy(None)
1915+
1916+
def test_basic_await(self):
1917+
async def async_gen():
1918+
self.assertIs(agen_obj.ag_awaiter, awaiter_obj)
1919+
yield 10
1920+
1921+
async def awaiter(agen):
1922+
async for x in agen:
1923+
pass
1924+
1925+
agen_obj = async_gen()
1926+
awaiter_obj = awaiter(agen_obj)
1927+
self.assertIsNone(agen_obj.ag_awaiter)
1928+
self.loop.run_until_complete(awaiter_obj)
1929+
1930+
19061931
if __name__ == "__main__":
19071932
unittest.main()

Lib/test/test_asyncio/test_tasks.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2489,6 +2489,24 @@ def test_get_context(self):
24892489
finally:
24902490
loop.close()
24912491

2492+
def test_get_awaiter(self):
2493+
ctask = getattr(tasks, '_CTask', None)
2494+
if ctask is None or not issubclass(self.Task, ctask):
2495+
self.skipTest("Only subclasses of _CTask set cr_awaiter on wrapped coroutines")
2496+
2497+
async def coro():
2498+
self.assertIs(coro_obj.cr_awaiter, awaiter_obj)
2499+
return "ok"
2500+
2501+
async def awaiter(coro):
2502+
task = self.loop.create_task(coro)
2503+
return await task
2504+
2505+
coro_obj = coro()
2506+
awaiter_obj = awaiter(coro_obj)
2507+
self.assertIsNone(coro_obj.cr_awaiter)
2508+
self.assertEqual(self.loop.run_until_complete(awaiter_obj), "ok")
2509+
self.assertIsNone(coro_obj.cr_awaiter)
24922510

24932511
def add_subclass_tests(cls):
24942512
BaseTask = cls.Task
@@ -3237,6 +3255,22 @@ async def coro(s):
32373255
# NameError should not happen:
32383256
self.one_loop.call_exception_handler.assert_not_called()
32393257

3258+
def test_propagate_awaiter(self):
3259+
async def coro(idx):
3260+
self.assertIs(coro_objs[idx].cr_awaiter, awaiter_obj)
3261+
return "ok"
3262+
3263+
async def awaiter(coros):
3264+
tasks = [self.one_loop.create_task(c) for c in coros]
3265+
return await asyncio.gather(*tasks)
3266+
3267+
coro_objs = [coro(0), coro(1)]
3268+
awaiter_obj = awaiter(coro_objs)
3269+
self.assertIsNone(coro_objs[0].cr_awaiter)
3270+
self.assertIsNone(coro_objs[1].cr_awaiter)
3271+
self.assertEqual(self.one_loop.run_until_complete(awaiter_obj), ["ok", "ok"])
3272+
self.assertIsNone(coro_objs[0].cr_awaiter)
3273+
self.assertIsNone(coro_objs[1].cr_awaiter)
32403274

32413275
class RunCoroutineThreadsafeTests(test_utils.TestCase):
32423276
"""Test case for asyncio.run_coroutine_threadsafe."""
@@ -3449,5 +3483,57 @@ def tearDown(self):
34493483
super().tearDown()
34503484

34513485

3486+
3487+
class GetAsyncStackTests(test_utils.TestCase):
3488+
def setUp(self):
3489+
self.loop = asyncio.new_event_loop()
3490+
asyncio.set_event_loop(None)
3491+
3492+
def tearDown(self):
3493+
self.loop.close()
3494+
self.loop = None
3495+
asyncio.set_event_loop_policy(None)
3496+
3497+
def check_stack(self, frames, expected_funcs):
3498+
given = [f.f_code for f in frames]
3499+
expected = [f.__code__ for f in expected_funcs]
3500+
self.assertEqual(given, expected)
3501+
3502+
def test_single_task(self):
3503+
async def coro():
3504+
await coro2()
3505+
3506+
async def coro2():
3507+
stack = asyncio.get_async_stack()
3508+
self.check_stack(stack, [coro, coro2])
3509+
3510+
self.loop.run_until_complete(coro())
3511+
3512+
def test_cross_tasks(self):
3513+
async def coro():
3514+
t = asyncio.ensure_future(coro2())
3515+
await t
3516+
3517+
async def coro2():
3518+
t = asyncio.ensure_future(coro3())
3519+
await t
3520+
3521+
async def coro3():
3522+
stack = asyncio.get_async_stack()
3523+
self.check_stack(stack, [coro, coro2, coro3])
3524+
3525+
self.loop.run_until_complete(coro())
3526+
3527+
def test_cross_gather(self):
3528+
async def coro():
3529+
await asyncio.gather(coro2(), coro2())
3530+
3531+
async def coro2():
3532+
stack = asyncio.get_async_stack()
3533+
self.check_stack(stack, [coro, coro2])
3534+
3535+
self.loop.run_until_complete(coro())
3536+
3537+
34523538
if __name__ == '__main__':
34533539
unittest.main()

Lib/test/test_coroutines.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2478,5 +2478,63 @@ async def foo():
24782478
self.assertEqual(foo().send(None), 1)
24792479

24802480

2481+
2482+
class CoroutineAwaiterTest(unittest.TestCase):
2483+
def test_basic_await(self):
2484+
async def coro():
2485+
self.assertIs(coro_obj.cr_awaiter, awaiter_obj)
2486+
return "success"
2487+
2488+
async def awaiter():
2489+
return await coro_obj
2490+
2491+
coro_obj = coro()
2492+
awaiter_obj = awaiter()
2493+
self.assertIsNone(coro_obj.cr_awaiter)
2494+
self.assertEqual(run_async(awaiter_obj), ([], "success"))
2495+
2496+
class FakeFuture:
2497+
def __await__(self):
2498+
return iter(["future"])
2499+
2500+
def test_coro_outlives_awaiter(self):
2501+
async def coro():
2502+
await self.FakeFuture()
2503+
2504+
async def awaiter(cr):
2505+
await cr
2506+
2507+
coro_obj = coro()
2508+
self.assertIsNone(coro_obj.cr_awaiter)
2509+
awaiter_obj = awaiter(coro_obj)
2510+
self.assertIsNone(coro_obj.cr_awaiter)
2511+
2512+
v1 = awaiter_obj.send(None)
2513+
self.assertEqual(v1, "future")
2514+
self.assertIs(coro_obj.cr_awaiter, awaiter_obj)
2515+
2516+
awaiter_id = id(awaiter_obj)
2517+
del awaiter_obj
2518+
self.assertEqual(id(coro_obj.cr_awaiter), awaiter_id)
2519+
2520+
def test_async_gen_awaiter(self):
2521+
async def coro():
2522+
self.assertIs(coro_obj.cr_awaiter, agen)
2523+
await self.FakeFuture()
2524+
2525+
async def async_gen(cr):
2526+
await cr
2527+
yield "hi"
2528+
2529+
coro_obj = coro()
2530+
self.assertIsNone(coro_obj.cr_awaiter)
2531+
agen = async_gen(coro_obj)
2532+
self.assertIsNone(coro_obj.cr_awaiter)
2533+
2534+
v1 = agen.asend(None).send(None)
2535+
self.assertEqual(v1, "future")
2536+
2537+
2538+
24812539
if __name__=="__main__":
24822540
unittest.main()

Misc/ACKS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,6 +1356,7 @@ Noah Oxer
13561356
Joonas Paalasmaa
13571357
Yaroslav Pankovych
13581358
Martin Packman
1359+
Matt Page
13591360
Elisha Paine
13601361
Shriphani Palakodety
13611362
Julien Palard

0 commit comments

Comments
 (0)