Skip to content

Commit a5bbe2a

Browse files
committed
BugFix: nested asyncio loops failing
1 parent 92bbb7c commit a5bbe2a

File tree

2 files changed

+184
-33
lines changed

2 files changed

+184
-33
lines changed

coc/http.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,44 +30,19 @@
3030
import logging
3131
import aiohttp
3232
import asyncio
33-
import signal
34-
import sys
35-
import functools
3633

3734
from urllib.parse import urlencode
3835
from itertools import cycle
3936
from datetime import datetime
4037
from collections import deque
4138

39+
import coc.nest_asyncio
4240
from .errors import HTTPException, Maitenance, NotFound, InvalidArgument, Forbidden, InvalidCredentials
4341

4442
log = logging.getLogger(__name__)
4543
KEY_MINIMUM, KEY_MAXIMUM = 1, 10
4644

4745

48-
def timeout(seconds, error_message='Client timed out.'):
49-
def decorated(func):
50-
def _handle_timeout(signum, frame):
51-
raise TimeoutError(error_message)
52-
53-
def wrapper(*args, **kwargs):
54-
if sys.platform == 'win32':
55-
# TODO: Fix for windows
56-
# for now just return function and ignore the problem
57-
return func(*args, **kwargs)
58-
59-
signal.signal(signal.SIGALRM, _handle_timeout)
60-
signal.alarm(seconds)
61-
try:
62-
result = func(*args, **kwargs)
63-
finally:
64-
signal.alarm(0)
65-
return result
66-
67-
return functools.wraps(func)(wrapper)
68-
return decorated
69-
70-
7146
async def json_or_text(response):
7247
try:
7348
ret = await response.json()
@@ -137,6 +112,7 @@ class HTTPClient:
137112
def __init__(self, client, loop, email, password,
138113
key_names, key_count, throttle_limit):
139114
self.client = client
115+
coc.nest_asyncio.apply(loop)
140116
self.loop = loop
141117
self.email = email
142118
self.password = password
@@ -149,7 +125,7 @@ def __init__(self, client, loop, email, password,
149125
self.__lock = asyncio.Semaphore(per_second)
150126
self.__throttle = Throttler(per_second, loop=self.loop)
151127

152-
asyncio.ensure_future(self.get_keys())
128+
loop.run_until_complete(self.get_keys())
153129

154130
async def get_keys(self):
155131
self.__session = aiohttp.ClientSession(loop=self.loop)
@@ -183,18 +159,12 @@ async def get_keys(self):
183159
async def close(self):
184160
if self.__session:
185161
await self.__session.close()
186-
187-
@timeout(60, "Client timed out while attempting to establish a connection to the Developer Portal")
188-
async def ensure_logged_in(self):
189-
while not hasattr(self, 'keys'):
190-
await asyncio.sleep(0.1)
191162

192163
async def request(self, route, **kwargs):
193164
method = route.method
194165
url = route.url
195166

196167
if 'headers' not in kwargs:
197-
await self.ensure_logged_in()
198168
key = next(self.keys)
199169

200170
headers = {

coc/nest_asyncio.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import sys
2+
import asyncio
3+
import heapq
4+
5+
6+
def apply(loop=None):
7+
"""
8+
Patch asyncio to make its event loop reentrent.
9+
"""
10+
loop = loop or asyncio.get_event_loop()
11+
if not isinstance(loop, asyncio.BaseEventLoop):
12+
raise ValueError('Can\'t patch loop of type %s' % type(loop))
13+
if hasattr(loop, '_run_until_complete_orig'):
14+
# already patched
15+
return
16+
_patch_asyncio()
17+
_patch_loop(loop)
18+
_patch_task()
19+
_patch_handle()
20+
21+
22+
def _patch_asyncio():
23+
"""
24+
Patch asyncio module to use pure Python tasks and futures,
25+
use module level _current_tasks, all_tasks and patch run method.
26+
"""
27+
def run(future, *, debug=False):
28+
loop = asyncio.get_event_loop()
29+
run_orig = asyncio._run_orig # noqa
30+
if run_orig and not loop.is_running():
31+
return run_orig(future, debug=debug)
32+
else:
33+
loop.set_debug(debug)
34+
return loop.run_until_complete(future)
35+
36+
if sys.version_info >= (3, 6, 0):
37+
asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = \
38+
asyncio.tasks._PyTask
39+
asyncio.Future = asyncio.futures._CFuture = asyncio.futures.Future = \
40+
asyncio.futures._PyFuture
41+
if sys.version_info < (3, 7, 0):
42+
asyncio.tasks._current_tasks = asyncio.tasks.Task._current_tasks # noqa
43+
asyncio.all_tasks = asyncio.tasks.Task.all_tasks # noqa
44+
if not hasattr(asyncio, '_run_orig'):
45+
asyncio._run_orig = getattr(asyncio, 'run', None)
46+
asyncio.run = run
47+
48+
49+
def _patch_loop(loop):
50+
"""
51+
Patch loop to make it reentrent.
52+
"""
53+
def run_until_complete(self, future):
54+
if self.is_running():
55+
self._check_closed()
56+
f = asyncio.ensure_future(future)
57+
if f is not future:
58+
f._log_destroy_pending = False
59+
while not f.done():
60+
run_once(self)
61+
return f.result()
62+
else:
63+
return self._run_until_complete_orig(future)
64+
65+
bogus_handle = asyncio.Handle(None, None, loop)
66+
bogus_handle.cancel()
67+
68+
def run_once(self):
69+
ready = self._ready
70+
scheduled = self._scheduled
71+
72+
# remove bogus handles to get more efficient timeout
73+
while ready and ready[0] is bogus_handle:
74+
ready.popleft()
75+
nready = len(ready)
76+
77+
while scheduled and scheduled[0]._cancelled:
78+
self._timer_cancelled_count -= 1
79+
handle = heapq.heappop(scheduled)
80+
handle._scheduled = False
81+
82+
timeout = None
83+
if ready or self._stopping:
84+
timeout = 0
85+
elif scheduled:
86+
when = scheduled[0]._when
87+
timeout = max(0, when - self.time())
88+
89+
event_list = self._selector.select(timeout)
90+
self._process_events(event_list)
91+
92+
end_time = self.time() + self._clock_resolution
93+
while scheduled:
94+
handle = scheduled[0]
95+
if handle._when >= end_time:
96+
break
97+
handle = heapq.heappop(scheduled)
98+
handle._scheduled = False
99+
ready.append(handle)
100+
101+
self._nesting_level += 1
102+
ntodo = len(ready)
103+
for _ in range(ntodo):
104+
if not ready:
105+
break
106+
handle = ready.popleft()
107+
if handle._cancelled:
108+
continue
109+
handle._run()
110+
handle = None
111+
self._nesting_level -= 1
112+
113+
# add bogus handles to keep loop._run_once happy
114+
if nready and self._nesting_level == 0:
115+
ready.extendleft([bogus_handle] * nready)
116+
117+
cls = loop.__class__
118+
cls._run_until_complete_orig = cls.run_until_complete
119+
cls.run_until_complete = run_until_complete
120+
cls._nesting_level = 0
121+
122+
123+
def _patch_task():
124+
"""
125+
Patch the Task's step and enter/leave methods to make it reentrant.
126+
"""
127+
def step(task, exc=None):
128+
curr_task = curr_tasks.get(task._loop)
129+
try:
130+
step_orig(task, exc)
131+
finally:
132+
if curr_task is None:
133+
curr_tasks.pop(task._loop, None)
134+
else:
135+
curr_tasks[task._loop] = curr_task
136+
137+
Task = asyncio.Task
138+
if sys.version_info >= (3, 7, 0):
139+
140+
def enter_task(loop, task):
141+
curr_tasks[loop] = task
142+
143+
def leave_task(loop, task):
144+
del curr_tasks[loop]
145+
146+
asyncio.tasks._enter_task = enter_task
147+
asyncio.tasks._leave_task = leave_task
148+
curr_tasks = asyncio.tasks._current_tasks
149+
step_orig = Task._Task__step
150+
Task._Task__step = step
151+
else:
152+
curr_tasks = Task._current_tasks
153+
step_orig = Task._step
154+
Task._step = step
155+
156+
157+
def _patch_handle():
158+
"""
159+
Patch Handle to allow recursive calls.
160+
"""
161+
def run(self):
162+
try:
163+
ctx = self._context.copy()
164+
ctx.run(self._callback, *self._args)
165+
except Exception as exc:
166+
cb = format_helpers._format_callback_source(
167+
self._callback, self._args)
168+
msg = 'Exception in callback {}'.format(cb)
169+
context = {
170+
'message': msg,
171+
'exception': exc,
172+
'handle': self,
173+
}
174+
if self._source_traceback:
175+
context['source_traceback'] = self._source_traceback
176+
self._loop.call_exception_handler(context)
177+
self = None
178+
179+
if sys.version_info >= (3, 7, 0):
180+
from asyncio import format_helpers
181+
asyncio.events.Handle._run = run

0 commit comments

Comments
 (0)