-
-
Notifications
You must be signed in to change notification settings - Fork 184
Add asyncio.run() and asyncio.run_forever() functions. #465
base: master
Are you sure you want to change the base?
Changes from 6 commits
db2fe1d
9acdceb
f24ff30
fa721ee
b8b0fa0
3c90364
7e67b48
275072a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| """asyncio.run() and asyncio.run_forever() functions.""" | ||
|
|
||
| __all__ = ['run', 'run_forever'] | ||
|
|
||
| import inspect | ||
| import threading | ||
|
|
||
| from . import coroutines | ||
| from . import events | ||
|
|
||
|
|
||
| def _cleanup(loop): | ||
| try: | ||
| # `shutdown_asyncgens` was added in Python 3.6; not all | ||
| # event loops might support it. | ||
| shutdown_asyncgens = loop.shutdown_asyncgens | ||
| except AttributeError: | ||
| pass | ||
| else: | ||
| loop.run_until_complete(shutdown_asyncgens()) | ||
| finally: | ||
| events.set_event_loop(None) | ||
| loop.close() | ||
|
|
||
|
|
||
| def run(main, *, debug=False): | ||
| """Run a coroutine. | ||
|
|
||
| This function runs the passed coroutine, taking care of | ||
| managing the asyncio event loop and finalizing asynchronous | ||
| generators. | ||
|
|
||
| This function must be called from the main thread, and it | ||
| cannot be called when another asyncio event loop is running. | ||
|
|
||
| If debug is True, the event loop will be run in debug mode. | ||
|
|
||
| This function should be used as a main entry point for | ||
| asyncio programs, and should not be used to call asynchronous | ||
| APIs. | ||
|
|
||
| Example:: | ||
|
|
||
| async def main(): | ||
| await asyncio.sleep(1) | ||
| print('hello') | ||
|
|
||
| asyncio.run(main()) | ||
| """ | ||
| if events._get_running_loop() is not None: | ||
| raise RuntimeError( | ||
| "asyncio.run() cannot be called from a running event loop") | ||
| if not isinstance(threading.current_thread(), threading._MainThread): | ||
| raise RuntimeError( | ||
| "asyncio.run() must be called from the main thread") | ||
| if not coroutines.iscoroutine(main): | ||
| raise ValueError("a coroutine was expected, got {!r}".format(main)) | ||
|
|
||
| loop = events.new_event_loop() | ||
| try: | ||
| events.set_event_loop(loop) | ||
|
|
||
| if debug: | ||
| loop.set_debug(True) | ||
|
|
||
| return loop.run_until_complete(main) | ||
| finally: | ||
| _cleanup(loop) | ||
|
|
||
|
|
||
| def run_forever(main, *, debug=False): | ||
| """Run asyncio loop. | ||
|
|
||
| main must be an asynchronous generator with one yield, separating | ||
| program initialization from cleanup logic. | ||
|
|
||
| If debug is True, the event loop will be run in debug mode. | ||
|
|
||
| This function should be used as a main entry point for | ||
| asyncio programs, and should not be used to call asynchronous | ||
| APIs. | ||
|
|
||
| Example: | ||
|
|
||
| async def main(): | ||
| server = await asyncio.start_server(...) | ||
| try: | ||
| yield # <- Let event loop run forever. | ||
| except KeyboardInterrupt: | ||
| print('^C received; exiting.') | ||
| finally: | ||
| server.close() | ||
| await server.wait_closed() | ||
|
|
||
| asyncio.run_forever(main()) | ||
| """ | ||
| if not hasattr(inspect, 'isasyncgen'): | ||
| raise NotImplementedError | ||
|
|
||
| if events._get_running_loop() is not None: | ||
| raise RuntimeError( | ||
| "asyncio.run_forever() cannot be called from a running event loop") | ||
| if not isinstance(threading.current_thread(), threading._MainThread): | ||
| raise RuntimeError( | ||
| "asyncio.run() must be called from the main thread") | ||
|
||
| if not inspect.isasyncgen(main): | ||
| raise ValueError( | ||
| "an asynchronous generator was expected, got {!r}".format(main)) | ||
|
|
||
| loop = events.new_event_loop() | ||
| try: | ||
| events.set_event_loop(loop) | ||
| if debug: | ||
| loop.set_debug(True) | ||
|
|
||
| ret = None | ||
| try: | ||
| ret = loop.run_until_complete(main.asend(None)) | ||
| except StopAsyncIteration as ex: | ||
| return | ||
| if ret is not None: | ||
| raise RuntimeError("only empty yield is supported") | ||
|
||
|
|
||
| yielded_twice = False | ||
| try: | ||
| loop.run_forever() | ||
| except BaseException as ex: | ||
| try: | ||
| loop.run_until_complete(main.athrow(ex)) | ||
| except StopAsyncIteration as ex: | ||
| pass | ||
| else: | ||
| yielded_twice = True | ||
| else: | ||
| try: | ||
| loop.run_until_complete(main.asend(None)) | ||
| except StopAsyncIteration as ex: | ||
| pass | ||
| else: | ||
| yielded_twice = True | ||
|
|
||
| if yielded_twice: | ||
| raise RuntimeError("only one yield is supported") | ||
|
||
|
|
||
| finally: | ||
| _cleanup(loop) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,205 @@ | ||
| """Tests asyncio.run() and asyncio.run_forever().""" | ||
|
|
||
| import asyncio | ||
| import unittest | ||
| import sys | ||
|
|
||
| from unittest import mock | ||
|
|
||
|
|
||
| class TestPolicy(asyncio.AbstractEventLoopPolicy): | ||
|
|
||
| def __init__(self, loop_factory): | ||
| self.loop_factory = loop_factory | ||
| self.loop = None | ||
|
|
||
| def get_event_loop(self): | ||
| # shouldn't ever be called by asyncio.run() | ||
| # or asyncio.run_forever() | ||
| raise RuntimeError | ||
|
|
||
| def new_event_loop(self): | ||
| return self.loop_factory() | ||
|
|
||
| def set_event_loop(self, loop): | ||
| if loop is not None: | ||
| # we want to check if the loop is closed | ||
| # in BaseTest.tearDown | ||
| self.loop = loop | ||
|
|
||
|
|
||
| class BaseTest(unittest.TestCase): | ||
|
|
||
| def new_loop(self): | ||
| loop = asyncio.BaseEventLoop() | ||
| loop._process_events = mock.Mock() | ||
| loop._selector = mock.Mock() | ||
| loop._selector.select.return_value = () | ||
| loop.shutdown_ag_run = False | ||
|
|
||
| async def shutdown_asyncgens(): | ||
| loop.shutdown_ag_run = True | ||
| loop.shutdown_asyncgens = shutdown_asyncgens | ||
|
|
||
| return loop | ||
|
|
||
| def setUp(self): | ||
| super().setUp() | ||
|
|
||
| policy = TestPolicy(self.new_loop) | ||
| asyncio.set_event_loop_policy(policy) | ||
|
|
||
| def tearDown(self): | ||
| policy = asyncio.get_event_loop_policy() | ||
| if policy.loop is not None: | ||
| self.assertTrue(policy.loop.is_closed()) | ||
| self.assertTrue(policy.loop.shutdown_ag_run) | ||
|
|
||
| asyncio.set_event_loop_policy(None) | ||
| super().tearDown() | ||
|
|
||
|
|
||
| class RunTests(BaseTest): | ||
|
|
||
| def test_asyncio_run_return(self): | ||
| async def main(): | ||
| await asyncio.sleep(0) | ||
| return 42 | ||
|
|
||
| self.assertEqual(asyncio.run(main()), 42) | ||
|
|
||
| def test_asyncio_run_raises(self): | ||
| async def main(): | ||
| await asyncio.sleep(0) | ||
| raise ValueError('spam') | ||
|
|
||
| with self.assertRaisesRegex(ValueError, 'spam'): | ||
| asyncio.run(main()) | ||
|
|
||
| def test_asyncio_run_only_coro(self): | ||
| for o in {1, lambda: None}: | ||
| with self.subTest(obj=o), \ | ||
| self.assertRaisesRegex(ValueError, | ||
| 'a coroutine was expected'): | ||
| asyncio.run(o) | ||
|
|
||
| def test_asyncio_run_debug(self): | ||
| async def main(expected): | ||
| loop = asyncio.get_event_loop() | ||
| self.assertIs(loop.get_debug(), expected) | ||
|
|
||
| asyncio.run(main(False)) | ||
| asyncio.run(main(True), debug=True) | ||
|
|
||
| def test_asyncio_run_from_running_loop(self): | ||
| async def main(): | ||
| asyncio.run(main()) | ||
|
|
||
| with self.assertRaisesRegex(RuntimeError, | ||
| 'cannot be called from a running'): | ||
| asyncio.run(main()) | ||
|
|
||
|
|
||
| class RunForeverTests(BaseTest): | ||
|
|
||
| def stop_soon(self, *, exc=None): | ||
| loop = asyncio.get_event_loop() | ||
|
|
||
| if exc: | ||
| def throw(): | ||
| raise exc | ||
| loop.call_later(0.01, throw) | ||
| else: | ||
| loop.call_later(0.01, loop.stop) | ||
|
|
||
| def test_asyncio_run_forever_return(self): | ||
| async def main(): | ||
| if 0: | ||
| yield | ||
| return | ||
|
|
||
| self.assertIsNone(asyncio.run_forever(main())) | ||
|
|
||
| def test_asyncio_run_forever_non_none_yield(self): | ||
| async def main(): | ||
| yield 1 | ||
|
|
||
| with self.assertRaisesRegex(RuntimeError, 'only empty'): | ||
| self.assertIsNone(asyncio.run_forever(main())) | ||
|
|
||
| def test_asyncio_run_forever_raises_before_yield(self): | ||
| async def main(): | ||
| await asyncio.sleep(0) | ||
| raise ValueError('spam') | ||
| yield | ||
|
|
||
| with self.assertRaisesRegex(ValueError, 'spam'): | ||
| asyncio.run_forever(main()) | ||
|
|
||
| def test_asyncio_run_forever_raises_after_yield(self): | ||
| async def main(): | ||
| self.stop_soon() | ||
| yield | ||
| raise ValueError('spam') | ||
|
|
||
| with self.assertRaisesRegex(ValueError, 'spam'): | ||
| asyncio.run_forever(main()) | ||
|
|
||
| def test_asyncio_run_forever_two_yields(self): | ||
| async def main(): | ||
| self.stop_soon() | ||
| yield | ||
| yield | ||
| raise ValueError('spam') | ||
|
|
||
| with self.assertRaisesRegex(RuntimeError, 'only one yield'): | ||
| asyncio.run_forever(main()) | ||
|
|
||
| def test_asyncio_run_forever_only_ag(self): | ||
| async def coro(): | ||
| pass | ||
|
|
||
| for o in {1, lambda: None, coro()}: | ||
| with self.subTest(obj=o), \ | ||
| self.assertRaisesRegex(ValueError, | ||
| 'an asynchronous.*was expected'): | ||
| asyncio.run_forever(o) | ||
|
|
||
| def test_asyncio_run_forever_debug(self): | ||
| async def main(expected): | ||
| loop = asyncio.get_event_loop() | ||
| self.assertIs(loop.get_debug(), expected) | ||
| if 0: | ||
| yield | ||
|
|
||
| asyncio.run_forever(main(False)) | ||
| asyncio.run_forever(main(True), debug=True) | ||
|
|
||
| def test_asyncio_run_forever_from_running_loop(self): | ||
| async def main(): | ||
| asyncio.run_forever(main()) | ||
| if 0: | ||
| yield | ||
|
|
||
| with self.assertRaisesRegex(RuntimeError, | ||
| 'cannot be called from a running'): | ||
| asyncio.run_forever(main()) | ||
|
|
||
| def test_asyncio_run_forever_base_exception(self): | ||
| vi = sys.version_info | ||
| if vi[:2] != (3, 6) or vi.releaselevel == 'beta' and vi.serial < 4: | ||
| # See http://bugs.python.org/issue28721 for details. | ||
| raise unittest.SkipTest( | ||
| 'this test requires Python 3.6b4 or greater') | ||
|
|
||
| class MyExc(BaseException): | ||
| pass | ||
|
|
||
| async def main(): | ||
| self.stop_soon(exc=MyExc) | ||
| try: | ||
| yield | ||
| except MyExc: | ||
| pass | ||
|
|
||
| asyncio.run_forever(main()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe these two checks (no running loop and main thread) that appear here and in
run_forever()could be factored out to a helper function like you did for_cleanup(loop)?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to customize the error message for each function, so I guess a little bit of copy/paste is fine.