Skip to content

Commit d28ce02

Browse files
Golf Playerdevintang3
authored andcommitted
Add basic hooks during execution
This will enable tracking of execution process without subclassing the way papermill does.
1 parent 612a9e8 commit d28ce02

File tree

2 files changed

+67
-6
lines changed

2 files changed

+67
-6
lines changed

nbclient/client.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
DeadKernelError,
2626
)
2727
from .output_widget import OutputWidget
28-
from .util import ensure_async, run_sync
28+
from .util import ensure_async, run_sync, run_hook
2929

3030

3131
def timestamp() -> str:
@@ -245,6 +245,45 @@ class NotebookClient(LoggingConfigurable):
245245

246246
kernel_manager_class: KernelManager = Type(config=True, help='The kernel manager class to use.')
247247

248+
on_execution_start: t.Optional[t.Callable] = Any(
249+
default_value=None,
250+
allow_none=True,
251+
help=dedent("""
252+
Called after the kernel manager and kernel client are setup, and cells
253+
are about to execute.
254+
Called with kwargs `kernel_id`.
255+
"""),
256+
).tag(config=True)
257+
258+
on_cell_start: t.Optional[t.Callable] = Any(
259+
default_value=None,
260+
allow_none=True,
261+
help=dedent("""
262+
A callable which executes before a cell is executed.
263+
Called with kwargs `cell`, and `cell_index`.
264+
"""),
265+
).tag(config=True)
266+
267+
on_cell_complete: t.Optional[t.Callable] = Any(
268+
default_value=None,
269+
allow_none=True,
270+
help=dedent("""
271+
A callable which executes after a cell execution is complete. It is
272+
called even when a cell results in a failure.
273+
Called with kwargs `cell`, and `cell_index`.
274+
"""),
275+
).tag(config=True)
276+
277+
on_cell_error: t.Optional[t.Callable] = Any(
278+
default_value=None,
279+
allow_none=True,
280+
help=dedent("""
281+
A callable which executes when a cell execution results in an error.
282+
This is executed even if errors are suppressed with `cell_allows_errors`.
283+
Called with kwargs `cell`, and `cell_index`.
284+
"""),
285+
).tag(config=True)
286+
248287
@default('kernel_manager_class')
249288
def _kernel_manager_class_default(self) -> KernelManager:
250289
"""Use a dynamic default to avoid importing jupyter_client at startup"""
@@ -426,6 +465,7 @@ async def async_start_new_kernel_client(self) -> KernelClient:
426465
await self._async_cleanup_kernel()
427466
raise
428467
self.kc.allow_stdin = False
468+
run_hook(sself.on_execution_start)
429469
return self.kc
430470

431471
start_new_kernel_client = run_sync(async_start_new_kernel_client)
@@ -729,7 +769,11 @@ def _passed_deadline(self, deadline: int) -> bool:
729769
return True
730770
return False
731771

732-
def _check_raise_for_error(self, cell: NotebookNode, exec_reply: t.Optional[t.Dict]) -> None:
772+
def _check_raise_for_error(
773+
self,
774+
cell: NotebookNode,
775+
cell_index: int,
776+
exec_reply: t.Optional[t.Dict]) -> None:
733777

734778
if exec_reply is None:
735779
return None
@@ -744,8 +788,10 @@ def _check_raise_for_error(self, cell: NotebookNode, exec_reply: t.Optional[t.Di
744788
or "raises-exception" in cell.metadata.get("tags", [])
745789
)
746790

747-
if not cell_allows_errors:
748-
raise CellExecutionError.from_cell_and_msg(cell, exec_reply_content)
791+
if (exec_reply is not None) and exec_reply['content']['status'] == 'error':
792+
run_hook(self.on_cell_error, cell=cell, cell_index=cell_index)
793+
if self.force_raise_errors or not cell_allows_errors:
794+
raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
749795

750796
async def async_execute_cell(
751797
self,
@@ -805,11 +851,13 @@ async def async_execute_cell(
805851
self.allow_errors or "raises-exception" in cell.metadata.get("tags", [])
806852
)
807853

854+
run_hook(self.on_cell_start, cell=cell, cell_index=cell_index)
808855
parent_msg_id = await ensure_async(
809856
self.kc.execute(
810857
cell.source, store_history=store_history, stop_on_error=not cell_allows_errors
811858
)
812859
)
860+
run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index)
813861
# We launched a code cell to execute
814862
self.code_cells_executed += 1
815863
exec_timeout = self._get_timeout(cell)
@@ -843,7 +891,7 @@ async def async_execute_cell(
843891

844892
if execution_count:
845893
cell['execution_count'] = execution_count
846-
self._check_raise_for_error(cell, exec_reply)
894+
self._check_raise_for_error(cell, cell_index, exec_reply)
847895
self.nb['cells'][cell_index] = cell
848896
return cell
849897

nbclient/util.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import asyncio
77
import inspect
88
import sys
9-
from typing import Any, Awaitable, Callable, Union
9+
from typing import Any, Awaitable, Callable, Optional, Union
10+
from functools import partial
1011

1112

1213
def check_ipython() -> None:
@@ -102,3 +103,15 @@ async def ensure_async(obj: Union[Awaitable, Any]) -> Any:
102103
return result
103104
# obj doesn't need to be awaited
104105
return obj
106+
107+
108+
def run_hook(hook: Optional[Callable], **kwargs) -> None:
109+
if hook is None:
110+
return
111+
if inspect.iscoroutinefunction(hook):
112+
future = hook(**kwargs)
113+
else:
114+
loop = asyncio.get_event_loop()
115+
hook_with_kwargs = partial(hook, **kwargs)
116+
future = loop.run_in_executor(None, hook_with_kwargs)
117+
asyncio.ensure_future(future)

0 commit comments

Comments
 (0)