Skip to content

Commit 8b54421

Browse files
committed
Rebased with master and added tests
Run_hook is now async and renamed util to test_util so it gets picked up by pytest.
1 parent 0f6a12f commit 8b54421

File tree

6 files changed

+237
-50
lines changed

6 files changed

+237
-50
lines changed

docs/client.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,15 @@ on both versions. Here the traitlet ``kernel_name`` helps simplify and
9696
maintain consistency: we can just run a notebook twice, specifying first
9797
"python2" and then "python3" as the kernel name.
9898

99+
In addition to the two above, we also support traitlets for hooks. They are as
100+
follows: ``on_execution_start``, ``on_cell_start``, ``on_cell_complete``,
101+
``on_cell_error``. These traitlets allow specifying a ``Callable`` function,
102+
which will run at certain points during the notebook execution and is executed asynchronously.
103+
``on_execution_start`` will run when the notebook client is kicked off.
104+
``on_cell_start`` will run right before each cell is executed.
105+
``on_cell_complete`` will run right after the cell is executed.
106+
``on_cell_error`` will run if there is an error in the cell.
107+
99108
Handling errors and exceptions
100109
------------------------------
101110

nbclient/client.py

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,18 @@
1414
from jupyter_client.client import KernelClient
1515
from nbformat import NotebookNode
1616
from nbformat.v4 import output_from_msg
17-
from traitlets import Any, Bool, Dict, Enum, Integer, List, Type, Unicode, default
17+
from traitlets import (
18+
Any,
19+
Bool,
20+
Callable,
21+
Dict,
22+
Enum,
23+
Integer,
24+
List,
25+
Type,
26+
Unicode,
27+
default,
28+
)
1829
from traitlets.config.configurable import LoggingConfigurable
1930

2031
from .exceptions import (
@@ -25,7 +36,7 @@
2536
DeadKernelError,
2637
)
2738
from .output_widget import OutputWidget
28-
from .util import ensure_async, run_sync, run_hook
39+
from .util import ensure_async, run_hook, run_sync
2940

3041

3142
def timestamp() -> str:
@@ -245,43 +256,50 @@ class NotebookClient(LoggingConfigurable):
245256

246257
kernel_manager_class: KernelManager = Type(config=True, help='The kernel manager class to use.')
247258

248-
on_execution_start: t.Optional[t.Callable] = Any(
259+
on_execution_start: t.Optional[t.Callable] = Callable(
249260
default_value=None,
250261
allow_none=True,
251-
help=dedent("""
262+
help=dedent(
263+
"""
252264
Called after the kernel manager and kernel client are setup, and cells
253265
are about to execute.
254-
Called with kwargs `kernel_id`.
255-
"""),
266+
"""
267+
),
256268
).tag(config=True)
257269

258-
on_cell_start: t.Optional[t.Callable] = Any(
270+
on_cell_start: t.Optional[t.Callable] = Callable(
259271
default_value=None,
260272
allow_none=True,
261-
help=dedent("""
262-
A callable which executes before a cell is executed.
263-
Called with kwargs `cell`, and `cell_index`.
264-
"""),
273+
help=dedent(
274+
"""
275+
A callable which executes before a cell is executed.
276+
Called with kwargs `cell` and `cell_index`.
277+
"""
278+
),
265279
).tag(config=True)
266280

267-
on_cell_complete: t.Optional[t.Callable] = Any(
281+
on_cell_complete: t.Optional[t.Callable] = Callable(
268282
default_value=None,
269283
allow_none=True,
270-
help=dedent("""
271-
A callable which executes after a cell execution is complete. It is
272-
called even when a cell results in a failure.
273-
Called with kwargs `cell`, and `cell_index`.
274-
"""),
284+
help=dedent(
285+
"""
286+
A callable which executes after a cell execution is complete. It is
287+
called even when a cell results in a failure.
288+
Called with kwargs `cell` and `cell_index`.
289+
"""
290+
),
275291
).tag(config=True)
276292

277-
on_cell_error: t.Optional[t.Callable] = Any(
293+
on_cell_error: t.Optional[t.Callable] = Callable(
278294
default_value=None,
279295
allow_none=True,
280-
help=dedent("""
281-
A callable which executes when a cell execution results in an error.
282-
This is executed even if errors are suppressed with `cell_allows_errors`.
283-
Called with kwargs `cell`, and `cell_index`.
284-
"""),
296+
help=dedent(
297+
"""
298+
A callable which executes when a cell execution results in an error.
299+
This is executed even if errors are suppressed with `cell_allows_errors`.
300+
Called with kwargs `cell` and `cell_index`.
301+
"""
302+
),
285303
).tag(config=True)
286304

287305
@default('kernel_manager_class')
@@ -465,7 +483,7 @@ async def async_start_new_kernel_client(self) -> KernelClient:
465483
await self._async_cleanup_kernel()
466484
raise
467485
self.kc.allow_stdin = False
468-
run_hook(sself.on_execution_start)
486+
await run_hook(self.on_execution_start)
469487
return self.kc
470488

471489
start_new_kernel_client = run_sync(async_start_new_kernel_client)
@@ -769,11 +787,9 @@ def _passed_deadline(self, deadline: int) -> bool:
769787
return True
770788
return False
771789

772-
def _check_raise_for_error(
773-
self,
774-
cell: NotebookNode,
775-
cell_index: int,
776-
exec_reply: t.Optional[t.Dict]) -> None:
790+
async def _check_raise_for_error(
791+
self, cell: NotebookNode, cell_index: int, exec_reply: t.Optional[t.Dict]
792+
) -> None:
777793

778794
if exec_reply is None:
779795
return None
@@ -787,11 +803,9 @@ def _check_raise_for_error(
787803
or exec_reply_content.get('ename') in self.allow_error_names
788804
or "raises-exception" in cell.metadata.get("tags", [])
789805
)
790-
791-
if (exec_reply is not None) and exec_reply['content']['status'] == 'error':
792-
run_hook(self.on_cell_error, cell=cell, cell_index=cell_index)
793-
if self.force_raise_errors or not cell_allows_errors:
794-
raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
806+
await run_hook(self.on_cell_error, cell=cell, cell_index=cell_index)
807+
if not cell_allows_errors:
808+
raise CellExecutionError.from_cell_and_msg(cell, exec_reply_content)
795809

796810
async def async_execute_cell(
797811
self,
@@ -851,13 +865,13 @@ async def async_execute_cell(
851865
self.allow_errors or "raises-exception" in cell.metadata.get("tags", [])
852866
)
853867

854-
run_hook(self.on_cell_start, cell=cell, cell_index=cell_index)
868+
await run_hook(self.on_cell_start, cell=cell, cell_index=cell_index)
855869
parent_msg_id = await ensure_async(
856870
self.kc.execute(
857871
cell.source, store_history=store_history, stop_on_error=not cell_allows_errors
858872
)
859873
)
860-
run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index)
874+
await run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index)
861875
# We launched a code cell to execute
862876
self.code_cells_executed += 1
863877
exec_timeout = self._get_timeout(cell)
@@ -891,7 +905,7 @@ async def async_execute_cell(
891905

892906
if execution_count:
893907
cell['execution_count'] = execution_count
894-
self._check_raise_for_error(cell, cell_index, exec_reply)
908+
await self._check_raise_for_error(cell, cell_index, exec_reply)
895909
self.nb['cells'][cell_index] = cell
896910
return cell
897911

nbclient/tests/test_client.py

Lines changed: 155 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import warnings
1010
from base64 import b64decode, b64encode
1111
from queue import Empty
12+
from unittest.mock import AsyncMock as AMock
1213
from unittest.mock import MagicMock, Mock
1314

1415
import nbformat
@@ -345,11 +346,7 @@ def test_async_parallel_notebooks(capfd, tmpdir):
345346
res = notebook_resources()
346347

347348
with modified_env({"NBEXECUTE_TEST_PARALLEL_TMPDIR": str(tmpdir)}):
348-
tasks = [
349-
async_run_notebook(input_file.format(label=label), opts, res) for label in ("A", "B")
350-
]
351-
loop = asyncio.get_event_loop()
352-
loop.run_until_complete(asyncio.gather(*tasks))
349+
[async_run_notebook(input_file.format(label=label), opts, res) for label in ("A", "B")]
353350

354351
captured = capfd.readouterr()
355352
assert filter_messages_on_error_output(captured.err) == ""
@@ -370,9 +367,7 @@ def test_many_async_parallel_notebooks(capfd):
370367
# run once, to trigger creating the original context
371368
run_notebook(input_file, opts, res)
372369

373-
tasks = [async_run_notebook(input_file, opts, res) for i in range(4)]
374-
loop = asyncio.get_event_loop()
375-
loop.run_until_complete(asyncio.gather(*tasks))
370+
[async_run_notebook(input_file, opts, res) for i in range(4)]
376371

377372
captured = capfd.readouterr()
378373
assert filter_messages_on_error_output(captured.err) == ""
@@ -741,6 +736,80 @@ def test_widgets(self):
741736
assert 'version_major' in wdata
742737
assert 'version_minor' in wdata
743738

739+
def test_execution_hook(self):
740+
filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb')
741+
with open(filename) as f:
742+
input_nb = nbformat.read(f, 4)
743+
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
744+
executor = NotebookClient(
745+
input_nb,
746+
on_cell_start=hook1,
747+
on_cell_complete=hook2,
748+
on_cell_error=hook3,
749+
on_execution_start=hook4,
750+
)
751+
executor.execute()
752+
hook1.assert_called_once()
753+
hook2.assert_called_once()
754+
hook3.assert_not_called()
755+
hook4.assert_called_once()
756+
757+
def test_error_execution_hook_error(self):
758+
filename = os.path.join(current_dir, 'files', 'Error.ipynb')
759+
with open(filename) as f:
760+
input_nb = nbformat.read(f, 4)
761+
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
762+
executor = NotebookClient(
763+
input_nb,
764+
on_cell_start=hook1,
765+
on_cell_complete=hook2,
766+
on_cell_error=hook3,
767+
on_execution_start=hook4,
768+
)
769+
with pytest.raises(CellExecutionError):
770+
executor.execute()
771+
hook1.assert_called_once()
772+
hook2.assert_called_once()
773+
hook3.assert_called_once()
774+
hook4.assert_called_once()
775+
776+
def test_async_execution_hook(self):
777+
filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb')
778+
with open(filename) as f:
779+
input_nb = nbformat.read(f, 4)
780+
hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock()
781+
executor = NotebookClient(
782+
input_nb,
783+
on_cell_start=hook1,
784+
on_cell_complete=hook2,
785+
on_cell_error=hook3,
786+
on_execution_start=hook4,
787+
)
788+
executor.execute()
789+
hook1.assert_called_once()
790+
hook2.assert_called_once()
791+
hook3.assert_not_called()
792+
hook4.assert_called_once()
793+
794+
def test_error_async_execution_hook(self):
795+
filename = os.path.join(current_dir, 'files', 'Error.ipynb')
796+
with open(filename) as f:
797+
input_nb = nbformat.read(f, 4)
798+
hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock()
799+
executor = NotebookClient(
800+
input_nb,
801+
on_cell_start=hook1,
802+
on_cell_complete=hook2,
803+
on_cell_error=hook3,
804+
on_execution_start=hook4,
805+
)
806+
with pytest.raises(CellExecutionError):
807+
executor.execute().execute()
808+
hook1.assert_called_once()
809+
hook2.assert_called_once()
810+
hook3.assert_called_once()
811+
hook4.assert_called_once()
812+
744813

745814
class TestRunCell(NBClientTestsBase):
746815
"""Contains test functions for NotebookClient.execute_cell"""
@@ -1524,3 +1593,81 @@ def test_no_source(self, executor, cell_mock, message_mock):
15241593
assert message_mock.call_count == 0
15251594
# Should also consume the message stream
15261595
assert cell_mock.outputs == []
1596+
1597+
@prepare_cell_mocks()
1598+
def test_cell_hooks(self, executor, cell_mock, message_mock):
1599+
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
1600+
executor.on_cell_start = hook1
1601+
executor.on_cell_complete = hook2
1602+
executor.on_cell_error = hook3
1603+
executor.on_execution_start = hook4
1604+
executor.execute_cell(cell_mock, 0)
1605+
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
1606+
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
1607+
hook3.assert_not_called()
1608+
hook4.assert_not_called()
1609+
1610+
@prepare_cell_mocks(
1611+
{
1612+
'msg_type': 'error',
1613+
'header': {'msg_type': 'error'},
1614+
'content': {'ename': 'foo', 'evalue': 'bar', 'traceback': ['Boom']},
1615+
},
1616+
reply_msg={
1617+
'msg_type': 'execute_reply',
1618+
'header': {'msg_type': 'execute_reply'},
1619+
# ERROR
1620+
'content': {'status': 'error'},
1621+
},
1622+
)
1623+
def test_error_cell_hooks(self, executor, cell_mock, message_mock):
1624+
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
1625+
executor.on_cell_start = hook1
1626+
executor.on_cell_complete = hook2
1627+
executor.on_cell_error = hook3
1628+
executor.on_execution_start = hook4
1629+
with self.assertRaises(CellExecutionError):
1630+
executor.execute_cell(cell_mock, 0)
1631+
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
1632+
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
1633+
hook3.assert_called_once_with(cell=cell_mock, cell_index=0)
1634+
hook4.assert_not_called()
1635+
1636+
@prepare_cell_mocks()
1637+
def test_async_cell_hooks(self, executor, cell_mock, message_mock):
1638+
hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock()
1639+
executor.on_cell_start = hook1
1640+
executor.on_cell_complete = hook2
1641+
executor.on_cell_error = hook3
1642+
executor.on_execution_start = hook4
1643+
executor.execute_cell(cell_mock, 0)
1644+
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
1645+
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
1646+
hook3.assert_not_called()
1647+
hook4.assert_not_called()
1648+
1649+
@prepare_cell_mocks(
1650+
{
1651+
'msg_type': 'error',
1652+
'header': {'msg_type': 'error'},
1653+
'content': {'ename': 'foo', 'evalue': 'bar', 'traceback': ['Boom']},
1654+
},
1655+
reply_msg={
1656+
'msg_type': 'execute_reply',
1657+
'header': {'msg_type': 'execute_reply'},
1658+
# ERROR
1659+
'content': {'status': 'error'},
1660+
},
1661+
)
1662+
def test_error_async_cell_hooks(self, executor, cell_mock, message_mock):
1663+
hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock()
1664+
executor.on_cell_start = hook1
1665+
executor.on_cell_complete = hook2
1666+
executor.on_cell_error = hook3
1667+
executor.on_execution_start = hook4
1668+
with self.assertRaises(CellExecutionError):
1669+
executor.execute_cell(cell_mock, 0)
1670+
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
1671+
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
1672+
hook3.assert_called_once_with(cell=cell_mock, cell_index=0)
1673+
hook4.assert_not_called()

0 commit comments

Comments
 (0)