9
9
import warnings
10
10
from base64 import b64decode , b64encode
11
11
from queue import Empty
12
+ from typing import Any
12
13
from unittest .mock import MagicMock , Mock
13
14
14
15
import nbformat
@@ -78,11 +79,15 @@ class AsyncMock(Mock):
78
79
pass
79
80
80
81
81
- def make_async (mock_value ):
82
- async def _ ():
83
- return mock_value
84
-
85
- return _ ()
82
+ def make_future (obj : Any ) -> asyncio .Future :
83
+ try :
84
+ loop = asyncio .get_running_loop ()
85
+ except RuntimeError :
86
+ loop = asyncio .new_event_loop ()
87
+ asyncio .set_event_loop (loop )
88
+ future : asyncio .Future = asyncio .Future (loop = loop )
89
+ future .set_result (obj )
90
+ return future
86
91
87
92
88
93
def normalize_base64 (b64_text ):
@@ -169,7 +174,7 @@ def shell_channel_message_mock():
169
174
# Return the message generator for
170
175
# self.kc.shell_channel.get_msg => {'parent_header': {'msg_id': parent_id}}
171
176
return AsyncMock (
172
- return_value = make_async (
177
+ return_value = make_future (
173
178
NBClientTestsBase .merge_dicts (
174
179
{
175
180
'parent_header' : {'msg_id' : parent_id },
@@ -186,7 +191,7 @@ def iopub_messages_mock():
186
191
return AsyncMock (
187
192
side_effect = [
188
193
# Default the parent_header so mocks don't need to include this
189
- make_async (
194
+ make_future (
190
195
NBClientTestsBase .merge_dicts ({'parent_header' : {'msg_id' : parent_id }}, msg )
191
196
)
192
197
for msg in messages
@@ -215,7 +220,7 @@ def test_mock_wrapper(self):
215
220
iopub_channel = MagicMock (get_msg = message_mock ),
216
221
shell_channel = MagicMock (get_msg = shell_channel_message_mock ()),
217
222
execute = MagicMock (return_value = parent_id ),
218
- is_alive = MagicMock (return_value = make_async (True )),
223
+ is_alive = MagicMock (return_value = make_future (True )),
219
224
)
220
225
executor .parent_id = parent_id
221
226
return func (self , executor , cell_mock , message_mock )
@@ -387,11 +392,15 @@ def test_async_parallel_notebooks(capfd, tmpdir):
387
392
res = notebook_resources ()
388
393
389
394
with modified_env ({"NBEXECUTE_TEST_PARALLEL_TMPDIR" : str (tmpdir )}):
390
- tasks = [
391
- async_run_notebook (input_file .format (label = label ), opts , res ) for label in ("A" , "B" )
392
- ]
393
- loop = asyncio .get_event_loop ()
394
- loop .run_until_complete (asyncio .gather (* tasks ))
395
+
396
+ async def run_tasks ():
397
+ tasks = [
398
+ async_run_notebook (input_file .format (label = label ), opts , res )
399
+ for label in ("A" , "B" )
400
+ ]
401
+ await asyncio .gather (* tasks )
402
+
403
+ asyncio .run (run_tasks ())
395
404
396
405
captured = capfd .readouterr ()
397
406
assert filter_messages_on_error_output (captured .err ) == ""
@@ -412,9 +421,11 @@ def test_many_async_parallel_notebooks(capfd):
412
421
# run once, to trigger creating the original context
413
422
run_notebook (input_file , opts , res )
414
423
415
- tasks = [async_run_notebook (input_file , opts , res ) for i in range (4 )]
416
- loop = asyncio .get_event_loop ()
417
- loop .run_until_complete (asyncio .gather (* tasks ))
424
+ async def run_tasks ():
425
+ tasks = [async_run_notebook (input_file , opts , res ) for i in range (4 )]
426
+ await asyncio .gather (* tasks )
427
+
428
+ asyncio .run (run_tasks ())
418
429
419
430
captured = capfd .readouterr ()
420
431
assert filter_messages_on_error_output (captured .err ) == ""
@@ -966,7 +977,7 @@ def message_seq(messages):
966
977
967
978
message_mock .side_effect = message_seq (list (message_mock .side_effect )[:- 1 ])
968
979
executor .kc .shell_channel .get_msg = Mock (
969
- return_value = make_async ({'parent_header' : {'msg_id' : executor .parent_id }})
980
+ return_value = make_future ({'parent_header' : {'msg_id' : executor .parent_id }})
970
981
)
971
982
executor .raise_on_iopub_timeout = True
972
983
0 commit comments