Skip to content

Commit c31e7c4

Browse files
committed
Pre-empt current task before running handle, allowing unpatched tasks, fixes #80
1 parent 35618de commit c31e7c4

File tree

1 file changed

+14
-38
lines changed

1 file changed

+14
-38
lines changed

nest_asyncio.py

Lines changed: 14 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ def apply(loop=None):
1313
"""Patch asyncio to make its event loop reentrant."""
1414
_patch_asyncio()
1515
_patch_policy()
16-
_patch_task()
1716
_patch_tornado()
1817

1918
loop = loop or asyncio.get_event_loop()
@@ -126,9 +125,20 @@ def _run_once(self):
126125
break
127126
handle = ready.popleft()
128127
if not handle._cancelled:
129-
handle._run()
128+
# preempt the current task so that that checks in
129+
# Task.__step do not raise
130+
curr_task = curr_tasks.pop(self, None)
131+
132+
try:
133+
handle._run()
134+
finally:
135+
# restore the current task
136+
if curr_task is not None:
137+
curr_tasks[self] = curr_task
138+
130139
handle = None
131140

141+
132142
@contextmanager
133143
def manage_run(self):
134144
"""Set up the loop for running."""
@@ -193,45 +203,11 @@ def _check_running(self):
193203
os.name == 'nt' and issubclass(cls, asyncio.ProactorEventLoop))
194204
if sys.version_info < (3, 7, 0):
195205
cls._set_coroutine_origin_tracking = cls._set_coroutine_wrapper
206+
curr_tasks = asyncio.tasks._current_tasks \
207+
if sys.version_info >= (3, 7, 0) else asyncio.Task._current_tasks
196208
cls._nest_patched = True
197209

198210

199-
def _patch_task():
200-
"""Patch the Task's step and enter/leave methods to make it reentrant."""
201-
202-
def step(task, exc=None):
203-
curr_task = curr_tasks.get(task._loop)
204-
try:
205-
step_orig(task, exc)
206-
finally:
207-
if curr_task is None:
208-
curr_tasks.pop(task._loop, None)
209-
else:
210-
curr_tasks[task._loop] = curr_task
211-
212-
Task = asyncio.Task
213-
if hasattr(Task, '_nest_patched'):
214-
return
215-
if sys.version_info >= (3, 7, 0):
216-
217-
def enter_task(loop, task):
218-
curr_tasks[loop] = task
219-
220-
def leave_task(loop, task):
221-
curr_tasks.pop(loop, None)
222-
223-
asyncio.tasks._enter_task = enter_task
224-
asyncio.tasks._leave_task = leave_task
225-
curr_tasks = asyncio.tasks._current_tasks
226-
step_orig = Task._Task__step
227-
Task._Task__step = step
228-
else:
229-
curr_tasks = Task._current_tasks
230-
step_orig = Task._step
231-
Task._step = step
232-
Task._nest_patched = True
233-
234-
235211
def _patch_tornado():
236212
"""
237213
If tornado is imported before nest_asyncio, make tornado aware of

0 commit comments

Comments
 (0)