Skip to content

Commit c68b700

Browse files
Golf Playerdevintang3
authored andcommitted
Add basic hooks during execution
This will enable tracking of execution process without subclassing the way papermill does.
1 parent 74d0a6e commit c68b700

File tree

2 files changed

+67
-6
lines changed

2 files changed

+67
-6
lines changed

nbclient/client.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
DeadKernelError,
2727
)
2828
from .output_widget import OutputWidget
29-
from .util import ensure_async, run_sync
29+
from .util import ensure_async, run_sync, run_hook
3030

3131

3232
def timestamp(msg: Optional[Dict] = None) -> str:
@@ -261,6 +261,45 @@ class NotebookClient(LoggingConfigurable):
261261

262262
kernel_manager_class: KernelManager = Type(config=True, help='The kernel manager class to use.')
263263

264+
on_execution_start: t.Optional[t.Callable] = Any(
265+
default_value=None,
266+
allow_none=True,
267+
help=dedent("""
268+
Called after the kernel manager and kernel client are setup, and cells
269+
are about to execute.
270+
Called with kwargs `kernel_id`.
271+
"""),
272+
).tag(config=True)
273+
274+
on_cell_start: t.Optional[t.Callable] = Any(
275+
default_value=None,
276+
allow_none=True,
277+
help=dedent("""
278+
A callable which executes before a cell is executed.
279+
Called with kwargs `cell`, and `cell_index`.
280+
"""),
281+
).tag(config=True)
282+
283+
on_cell_complete: t.Optional[t.Callable] = Any(
284+
default_value=None,
285+
allow_none=True,
286+
help=dedent("""
287+
A callable which executes after a cell execution is complete. It is
288+
called even when a cell results in a failure.
289+
Called with kwargs `cell`, and `cell_index`.
290+
"""),
291+
).tag(config=True)
292+
293+
on_cell_error: t.Optional[t.Callable] = Any(
294+
default_value=None,
295+
allow_none=True,
296+
help=dedent("""
297+
A callable which executes when a cell execution results in an error.
298+
This is executed even if errors are suppressed with `cell_allows_errors`.
299+
Called with kwargs `cell`, and `cell_index`.
300+
"""),
301+
).tag(config=True)
302+
264303
@default('kernel_manager_class')
265304
def _kernel_manager_class_default(self) -> KernelManager:
266305
"""Use a dynamic default to avoid importing jupyter_client at startup"""
@@ -442,6 +481,7 @@ async def async_start_new_kernel_client(self) -> KernelClient:
442481
await self._async_cleanup_kernel()
443482
raise
444483
self.kc.allow_stdin = False
484+
run_hook(sself.on_execution_start)
445485
return self.kc
446486

447487
start_new_kernel_client = run_sync(async_start_new_kernel_client)
@@ -745,7 +785,11 @@ def _passed_deadline(self, deadline: int) -> bool:
745785
return True
746786
return False
747787

748-
def _check_raise_for_error(self, cell: NotebookNode, exec_reply: t.Optional[t.Dict]) -> None:
788+
def _check_raise_for_error(
789+
self,
790+
cell: NotebookNode,
791+
cell_index: int,
792+
exec_reply: t.Optional[t.Dict]) -> None:
749793

750794
if exec_reply is None:
751795
return None
@@ -760,8 +804,10 @@ def _check_raise_for_error(self, cell: NotebookNode, exec_reply: t.Optional[t.Di
760804
or "raises-exception" in cell.metadata.get("tags", [])
761805
)
762806

763-
if not cell_allows_errors:
764-
raise CellExecutionError.from_cell_and_msg(cell, exec_reply_content)
807+
if (exec_reply is not None) and exec_reply['content']['status'] == 'error':
808+
run_hook(self.on_cell_error, cell=cell, cell_index=cell_index)
809+
if self.force_raise_errors or not cell_allows_errors:
810+
raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
765811

766812
async def async_execute_cell(
767813
self,
@@ -821,11 +867,13 @@ async def async_execute_cell(
821867
self.allow_errors or "raises-exception" in cell.metadata.get("tags", [])
822868
)
823869

870+
run_hook(self.on_cell_start, cell=cell, cell_index=cell_index)
824871
parent_msg_id = await ensure_async(
825872
self.kc.execute(
826873
cell.source, store_history=store_history, stop_on_error=not cell_allows_errors
827874
)
828875
)
876+
run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index)
829877
# We launched a code cell to execute
830878
self.code_cells_executed += 1
831879
exec_timeout = self._get_timeout(cell)
@@ -859,7 +907,7 @@ async def async_execute_cell(
859907

860908
if execution_count:
861909
cell['execution_count'] = execution_count
862-
self._check_raise_for_error(cell, exec_reply)
910+
self._check_raise_for_error(cell, cell_index, exec_reply)
863911
self.nb['cells'][cell_index] = cell
864912
return cell
865913

nbclient/util.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import asyncio
77
import inspect
88
import sys
9-
from typing import Any, Awaitable, Callable, Union
9+
from typing import Any, Awaitable, Callable, Optional, Union
10+
from functools import partial
1011

1112

1213
def check_ipython() -> None:
@@ -102,3 +103,15 @@ async def ensure_async(obj: Union[Awaitable, Any]) -> Any:
102103
return result
103104
# obj doesn't need to be awaited
104105
return obj
106+
107+
108+
def run_hook(hook: Optional[Callable], **kwargs) -> None:
109+
if hook is None:
110+
return
111+
if inspect.iscoroutinefunction(hook):
112+
future = hook(**kwargs)
113+
else:
114+
loop = asyncio.get_event_loop()
115+
hook_with_kwargs = partial(hook, **kwargs)
116+
future = loop.run_in_executor(None, hook_with_kwargs)
117+
asyncio.ensure_future(future)

0 commit comments

Comments
 (0)