|
2 | 2 |
|
3 | 3 | import asyncio |
4 | 4 | import unittest |
| 5 | +import functools |
5 | 6 |
|
6 | 7 | from contextvars import ContextVar |
7 | 8 | from unittest import mock |
8 | 9 |
|
9 | 10 |
|
10 | 11 | def tearDownModule(): |
11 | | - asyncio._set_event_loop_policy(None) |
| 12 | + asyncio.set_event_loop_policy(None) |
12 | 13 |
|
13 | 14 |
|
14 | 15 | class ToThreadTests(unittest.IsolatedAsyncioTestCase): |
@@ -61,6 +62,41 @@ def get_ctx(): |
61 | 62 |
|
62 | 63 | self.assertEqual(result, 'parrot') |
63 | 64 |
|
| 65 | + @mock.patch('asyncio.base_events.BaseEventLoop.run_in_executor') |
| 66 | + async def test_to_thread_optimization_path(self, run_in_executor): |
| 67 | + # This test ensures that `to_thread` uses the correct execution path |
| 68 | + # based on whether the context is empty or not. |
| 69 | + |
| 70 | + # `to_thread` awaits the future returned by `run_in_executor`. |
| 71 | + # We need to provide a completed future as a return value for the mock. |
| 72 | + fut = asyncio.Future() |
| 73 | + fut.set_result(None) |
| 74 | + run_in_executor.return_value = fut |
| 75 | + |
| 76 | + def myfunc(): |
| 77 | + pass |
| 78 | + |
| 79 | + # Test with an empty context (optimized path) |
| 80 | + await asyncio.to_thread(myfunc) |
| 81 | + run_in_executor.assert_called_once() |
| 82 | + |
| 83 | + callback = run_in_executor.call_args.args[1] |
| 84 | + self.assertIsInstance(callback, functools.partial) |
| 85 | + self.assertIs(callback.func, myfunc) |
| 86 | + run_in_executor.reset_mock() |
| 87 | + |
| 88 | + # Test with a non-empty context (standard path) |
| 89 | + var = ContextVar('var') |
| 90 | + var.set('value') |
| 91 | + |
| 92 | + await asyncio.to_thread(myfunc) |
| 93 | + run_in_executor.assert_called_once() |
| 94 | + |
| 95 | + callback = run_in_executor.call_args.args[1] |
| 96 | + self.assertIsInstance(callback, functools.partial) |
| 97 | + self.assertIsNot(callback.func, myfunc) # Should be ctx.run |
| 98 | + self.assertIs(callback.args[0], myfunc) |
| 99 | + |
64 | 100 |
|
65 | 101 | if __name__ == "__main__": |
66 | 102 | unittest.main() |
0 commit comments