Skip to content

Commit 1de1dbd

Browse files
committed
Rebased with master and added tests
Run_hook is now async and renamed util to test_util so it gets picked up by pytest.
1 parent b423119 commit 1de1dbd

File tree

6 files changed

+238
-57
lines changed

6 files changed

+238
-57
lines changed

docs/client.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,15 @@ on both versions. Here the traitlet ``kernel_name`` helps simplify and
9696
maintain consistency: we can just run a notebook twice, specifying first
9797
"python2" and then "python3" as the kernel name.
9898

99+
In addition to the two above, we also support traitlets for hooks. They are as
100+
follows: ``on_execution_start``, ``on_cell_start``, ``on_cell_complete``,
101+
``on_cell_error``. These traitlets allow specifying a ``Callable`` function,
102+
which will run at certain points during the notebook execution and is executed asynchronously.
103+
``on_execution_start`` will run when the notebook client is kicked off.
104+
``on_cell_start`` will run right before each cell is executed.
105+
``on_cell_complete`` will run right after the cell is executed.
106+
``on_cell_error`` will run if there is an error in the cell.
107+
99108
Handling errors and exceptions
100109
------------------------------
101110

nbclient/client.py

Lines changed: 53 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,18 @@
1414
from jupyter_client.client import KernelClient
1515
from nbformat import NotebookNode
1616
from nbformat.v4 import output_from_msg
17-
from traitlets import Any, Bool, Dict, Enum, Integer, List, Type, Unicode, default
17+
from traitlets import (
18+
Any,
19+
Bool,
20+
Callable,
21+
Dict,
22+
Enum,
23+
Integer,
24+
List,
25+
Type,
26+
Unicode,
27+
default,
28+
)
1829
from traitlets.config.configurable import LoggingConfigurable
1930

2031
from .exceptions import (
@@ -25,7 +36,7 @@
2536
DeadKernelError,
2637
)
2738
from .output_widget import OutputWidget
28-
from .util import ensure_async, run_sync, run_hook
39+
from .util import ensure_async, run_hook, run_sync
2940

3041

3142
def timestamp() -> str:
@@ -245,43 +256,50 @@ class NotebookClient(LoggingConfigurable):
245256

246257
kernel_manager_class: KernelManager = Type(config=True, help='The kernel manager class to use.')
247258

248-
on_execution_start: t.Optional[t.Callable] = Any(
259+
on_execution_start: t.Optional[t.Callable] = Callable(
249260
default_value=None,
250261
allow_none=True,
251-
help=dedent("""
252-
Called after the kernel manager and kernel client are setup, and cells
253-
are about to execute.
254-
Called with kwargs `kernel_id`.
255-
"""),
262+
help=dedent(
263+
"""
264+
Called after the kernel manager and kernel client are setup, and cells
265+
are about to execute.
266+
"""
267+
),
256268
).tag(config=True)
257269

258-
on_cell_start: t.Optional[t.Callable] = Any(
270+
on_cell_start: t.Optional[t.Callable] = Callable(
259271
default_value=None,
260272
allow_none=True,
261-
help=dedent("""
262-
A callable which executes before a cell is executed.
263-
Called with kwargs `cell`, and `cell_index`.
264-
"""),
273+
help=dedent(
274+
"""
275+
A callable which executes before a cell is executed.
276+
Called with kwargs `cell` and `cell_index`.
277+
"""
278+
),
265279
).tag(config=True)
266280

267-
on_cell_complete: t.Optional[t.Callable] = Any(
281+
on_cell_complete: t.Optional[t.Callable] = Callable(
268282
default_value=None,
269283
allow_none=True,
270-
help=dedent("""
271-
A callable which executes after a cell execution is complete. It is
272-
called even when a cell results in a failure.
273-
Called with kwargs `cell`, and `cell_index`.
274-
"""),
284+
help=dedent(
285+
"""
286+
A callable which executes after a cell execution is complete. It is
287+
called even when a cell results in a failure.
288+
Called with kwargs `cell` and `cell_index`.
289+
"""
290+
),
275291
).tag(config=True)
276292

277-
on_cell_error: t.Optional[t.Callable] = Any(
293+
on_cell_error: t.Optional[t.Callable] = Callable(
278294
default_value=None,
279295
allow_none=True,
280-
help=dedent("""
281-
A callable which executes when a cell execution results in an error.
282-
This is executed even if errors are suppressed with `cell_allows_errors`.
283-
Called with kwargs `cell`, and `cell_index`.
284-
"""),
296+
help=dedent(
297+
"""
298+
A callable which executes when a cell execution results in an error.
299+
This is executed even if errors are suppressed with `cell_allows_errors`.
300+
Called with kwargs `cell` and `cell_index`.
301+
"""
302+
),
285303
).tag(config=True)
286304

287305
@default('kernel_manager_class')
@@ -465,7 +483,7 @@ async def async_start_new_kernel_client(self) -> KernelClient:
465483
await self._async_cleanup_kernel()
466484
raise
467485
self.kc.allow_stdin = False
468-
run_hook(sself.on_execution_start)
486+
await run_hook(self.on_execution_start)
469487
return self.kc
470488

471489
start_new_kernel_client = run_sync(async_start_new_kernel_client)
@@ -769,11 +787,9 @@ def _passed_deadline(self, deadline: int) -> bool:
769787
return True
770788
return False
771789

772-
def _check_raise_for_error(
773-
self,
774-
cell: NotebookNode,
775-
cell_index: int,
776-
exec_reply: t.Optional[t.Dict]) -> None:
790+
async def _check_raise_for_error(
791+
self, cell: NotebookNode, cell_index: int, exec_reply: t.Optional[t.Dict]
792+
) -> None:
777793

778794
if exec_reply is None:
779795
return None
@@ -787,11 +803,9 @@ def _check_raise_for_error(
787803
or exec_reply_content.get('ename') in self.allow_error_names
788804
or "raises-exception" in cell.metadata.get("tags", [])
789805
)
790-
791-
if (exec_reply is not None) and exec_reply['content']['status'] == 'error':
792-
run_hook(self.on_cell_error, cell=cell, cell_index=cell_index)
793-
if self.force_raise_errors or not cell_allows_errors:
794-
raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
806+
await run_hook(self.on_cell_error, cell=cell, cell_index=cell_index)
807+
if not cell_allows_errors:
808+
raise CellExecutionError.from_cell_and_msg(cell, exec_reply_content)
795809

796810
async def async_execute_cell(
797811
self,
@@ -851,13 +865,13 @@ async def async_execute_cell(
851865
self.allow_errors or "raises-exception" in cell.metadata.get("tags", [])
852866
)
853867

854-
run_hook(self.on_cell_start, cell=cell, cell_index=cell_index)
868+
await run_hook(self.on_cell_start, cell=cell, cell_index=cell_index)
855869
parent_msg_id = await ensure_async(
856870
self.kc.execute(
857871
cell.source, store_history=store_history, stop_on_error=not cell_allows_errors
858872
)
859873
)
860-
run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index)
874+
await run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index)
861875
# We launched a code cell to execute
862876
self.code_cells_executed += 1
863877
exec_timeout = self._get_timeout(cell)
@@ -891,7 +905,7 @@ async def async_execute_cell(
891905

892906
if execution_count:
893907
cell['execution_count'] = execution_count
894-
self._check_raise_for_error(cell, cell_index, exec_reply)
908+
await self._check_raise_for_error(cell, cell_index, exec_reply)
895909
self.nb['cells'][cell_index] = cell
896910
return cell
897911

nbclient/tests/test_client.py

Lines changed: 154 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -345,11 +345,7 @@ def test_async_parallel_notebooks(capfd, tmpdir):
345345
res = notebook_resources()
346346

347347
with modified_env({"NBEXECUTE_TEST_PARALLEL_TMPDIR": str(tmpdir)}):
348-
tasks = [
349-
async_run_notebook(input_file.format(label=label), opts, res) for label in ("A", "B")
350-
]
351-
loop = asyncio.get_event_loop()
352-
loop.run_until_complete(asyncio.gather(*tasks))
348+
[async_run_notebook(input_file.format(label=label), opts, res) for label in ("A", "B")]
353349

354350
captured = capfd.readouterr()
355351
assert filter_messages_on_error_output(captured.err) == ""
@@ -370,9 +366,7 @@ def test_many_async_parallel_notebooks(capfd):
370366
# run once, to trigger creating the original context
371367
run_notebook(input_file, opts, res)
372368

373-
tasks = [async_run_notebook(input_file, opts, res) for i in range(4)]
374-
loop = asyncio.get_event_loop()
375-
loop.run_until_complete(asyncio.gather(*tasks))
369+
[async_run_notebook(input_file, opts, res) for i in range(4)]
376370

377371
captured = capfd.readouterr()
378372
assert filter_messages_on_error_output(captured.err) == ""
@@ -741,6 +735,80 @@ def test_widgets(self):
741735
assert 'version_major' in wdata
742736
assert 'version_minor' in wdata
743737

738+
def test_execution_hook(self):
739+
filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb')
740+
with open(filename) as f:
741+
input_nb = nbformat.read(f, 4)
742+
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
743+
executor = NotebookClient(
744+
input_nb,
745+
on_cell_start=hook1,
746+
on_cell_complete=hook2,
747+
on_cell_error=hook3,
748+
on_execution_start=hook4,
749+
)
750+
executor.execute()
751+
hook1.assert_called_once()
752+
hook2.assert_called_once()
753+
hook3.assert_not_called()
754+
hook4.assert_called_once()
755+
756+
def test_error_execution_hook_error(self):
757+
filename = os.path.join(current_dir, 'files', 'Error.ipynb')
758+
with open(filename) as f:
759+
input_nb = nbformat.read(f, 4)
760+
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
761+
executor = NotebookClient(
762+
input_nb,
763+
on_cell_start=hook1,
764+
on_cell_complete=hook2,
765+
on_cell_error=hook3,
766+
on_execution_start=hook4,
767+
)
768+
with pytest.raises(CellExecutionError):
769+
executor.execute()
770+
hook1.assert_called_once()
771+
hook2.assert_called_once()
772+
hook3.assert_called_once()
773+
hook4.assert_called_once()
774+
775+
def test_async_execution_hook(self):
776+
filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb')
777+
with open(filename) as f:
778+
input_nb = nbformat.read(f, 4)
779+
hook1, hook2, hook3, hook4 = AsyncMock(), AsyncMock(), AsyncMock(), AsyncMock()
780+
executor = NotebookClient(
781+
input_nb,
782+
on_cell_start=hook1,
783+
on_cell_complete=hook2,
784+
on_cell_error=hook3,
785+
on_execution_start=hook4,
786+
)
787+
executor.execute()
788+
hook1.assert_called_once()
789+
hook2.assert_called_once()
790+
hook3.assert_not_called()
791+
hook4.assert_called_once()
792+
793+
def test_error_async_execution_hook(self):
794+
filename = os.path.join(current_dir, 'files', 'Error.ipynb')
795+
with open(filename) as f:
796+
input_nb = nbformat.read(f, 4)
797+
hook1, hook2, hook3, hook4 = AsyncMock(), AsyncMock(), AsyncMock(), AsyncMock()
798+
executor = NotebookClient(
799+
input_nb,
800+
on_cell_start=hook1,
801+
on_cell_complete=hook2,
802+
on_cell_error=hook3,
803+
on_execution_start=hook4,
804+
)
805+
with pytest.raises(CellExecutionError):
806+
executor.execute().execute()
807+
hook1.assert_called_once()
808+
hook2.assert_called_once()
809+
hook3.assert_called_once()
810+
hook4.assert_called_once()
811+
744812

745813
class TestRunCell(NBClientTestsBase):
746814
"""Contains test functions for NotebookClient.execute_cell"""
@@ -1524,3 +1592,81 @@ def test_no_source(self, executor, cell_mock, message_mock):
15241592
assert message_mock.call_count == 0
15251593
# Should also consume the message stream
15261594
assert cell_mock.outputs == []
1595+
1596+
@prepare_cell_mocks()
1597+
def test_cell_hooks(self, executor, cell_mock, message_mock):
1598+
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
1599+
executor.on_cell_start = hook1
1600+
executor.on_cell_complete = hook2
1601+
executor.on_cell_error = hook3
1602+
executor.on_execution_start = hook4
1603+
executor.execute_cell(cell_mock, 0)
1604+
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
1605+
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
1606+
hook3.assert_not_called()
1607+
hook4.assert_not_called()
1608+
1609+
@prepare_cell_mocks(
1610+
{
1611+
'msg_type': 'error',
1612+
'header': {'msg_type': 'error'},
1613+
'content': {'ename': 'foo', 'evalue': 'bar', 'traceback': ['Boom']},
1614+
},
1615+
reply_msg={
1616+
'msg_type': 'execute_reply',
1617+
'header': {'msg_type': 'execute_reply'},
1618+
# ERROR
1619+
'content': {'status': 'error'},
1620+
},
1621+
)
1622+
def test_error_cell_hooks(self, executor, cell_mock, message_mock):
1623+
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
1624+
executor.on_cell_start = hook1
1625+
executor.on_cell_complete = hook2
1626+
executor.on_cell_error = hook3
1627+
executor.on_execution_start = hook4
1628+
with self.assertRaises(CellExecutionError):
1629+
executor.execute_cell(cell_mock, 0)
1630+
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
1631+
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
1632+
hook3.assert_called_once_with(cell=cell_mock, cell_index=0)
1633+
hook4.assert_not_called()
1634+
1635+
@prepare_cell_mocks()
1636+
def test_async_cell_hooks(self, executor, cell_mock, message_mock):
1637+
hook1, hook2, hook3, hook4 = AsyncMock(), AsyncMock(), AsyncMock(), AsyncMock()
1638+
executor.on_cell_start = hook1
1639+
executor.on_cell_complete = hook2
1640+
executor.on_cell_error = hook3
1641+
executor.on_execution_start = hook4
1642+
executor.execute_cell(cell_mock, 0)
1643+
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
1644+
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
1645+
hook3.assert_not_called()
1646+
hook4.assert_not_called()
1647+
1648+
@prepare_cell_mocks(
1649+
{
1650+
'msg_type': 'error',
1651+
'header': {'msg_type': 'error'},
1652+
'content': {'ename': 'foo', 'evalue': 'bar', 'traceback': ['Boom']},
1653+
},
1654+
reply_msg={
1655+
'msg_type': 'execute_reply',
1656+
'header': {'msg_type': 'execute_reply'},
1657+
# ERROR
1658+
'content': {'status': 'error'},
1659+
},
1660+
)
1661+
def test_error_async_cell_hooks(self, executor, cell_mock, message_mock):
1662+
hook1, hook2, hook3, hook4 = AsyncMock(), AsyncMock(), AsyncMock(), AsyncMock()
1663+
executor.on_cell_start = hook1
1664+
executor.on_cell_complete = hook2
1665+
executor.on_cell_error = hook3
1666+
executor.on_execution_start = hook4
1667+
with self.assertRaises(CellExecutionError):
1668+
executor.execute_cell(cell_mock, 0)
1669+
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
1670+
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
1671+
hook3.assert_called_once_with(cell=cell_mock, cell_index=0)
1672+
hook4.assert_not_called()

nbclient/tests/util.py renamed to nbclient/tests/test_util.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import asyncio
2+
from unittest.mock import MagicMock
23

4+
import pytest
35
import tornado
46

5-
from nbclient.util import run_sync
7+
from nbclient.util import run_hook, run_sync
68

79

810
@run_sync
@@ -55,3 +57,17 @@ async def run():
5557
assert some_sync_function() == 42
5658

5759
ioloop.run_sync(run)
60+
61+
62+
@pytest.mark.asyncio
63+
async def test_run_hook_sync():
64+
some_sync_function = MagicMock()
65+
await run_hook(some_sync_function)
66+
assert some_sync_function.call_count == 1
67+
68+
69+
@pytest.mark.asyncio
70+
async def test_run_hook_async():
71+
hook = MagicMock(return_value=some_async_function())
72+
await run_hook(hook)
73+
assert hook.call_count == 1

0 commit comments

Comments
 (0)