From 6ad4a43fa30e0e313bff1025404937de3c74d40f Mon Sep 17 00:00:00 2001 From: Dan Grahn Date: Thu, 17 Jul 2025 07:47:59 -0400 Subject: [PATCH 1/6] #329 Add on_cell_input_request handler --- nbclient/client.py | 54 +++++++++++++++++++++++++++++++++- tests/files/InputRequest.ipynb | 21 +++++++++++++ tests/test_client.py | 38 ++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 tests/files/InputRequest.ipynb diff --git a/nbclient/client.py b/nbclient/client.py index 936bb4a..448fbb1 100644 --- a/nbclient/client.py +++ b/nbclient/client.py @@ -325,6 +325,17 @@ class NotebookClient(LoggingConfigurable): ), ).tag(config=True) + on_cell_input_request = Callable( + default_value=None, + allow_none=True, + help=dedent( + """ + A callable which executes when a cell requests input. + Called with kwargs ``cell`` and ``cell_index``. + """ + ), + ) + on_cell_start = Callable( default_value=None, allow_none=True, @@ -572,7 +583,7 @@ async def async_start_new_kernel_client(self) -> KernelClient: ) await self._async_cleanup_kernel() raise - self.kc.allow_stdin = False + self.kc.allow_stdin = self.on_cell_input_request is not None await run_hook(self.on_notebook_start, notebook=self.nb) return self.kc @@ -758,6 +769,31 @@ def _update_display_id(self, display_id: str, msg: dict[str, t.Any]) -> None: outputs[output_idx]["data"] = out["data"] outputs[output_idx]["metadata"] = out["metadata"] + async def _async_poll_stdin_msg( + self, parent_msg_id: str, cell: NotebookNode, cell_index: int + ) -> None: + """Poll for stdin messages (input requests) from the kernel. + + This method runs in parallel with _async_poll_output_msg and handles + input requests by calling the on_cell_input_request callback and + sending the response back to the kernel. + """ + assert self.kc is not None + + while True: + try: + msg = await ensure_async(self.kc.stdin_channel.get_msg(timeout=None)) + if msg["parent_header"].get("msg_id") == parent_msg_id: + if msg["header"]["msg_type"] == "input_request": + response = await ensure_async( + self.on_cell_input_request(cell=cell, cell_index=cell_index) + ) + self.kc.input(response) + except Empty: + # Yield control to allow cancellation to be processed + await asyncio.sleep(0.01) + continue + async def _async_poll_for_reply( self, msg_id: str, @@ -996,6 +1032,14 @@ async def async_execute_cell( task_poll_output_msg = asyncio.ensure_future( self._async_poll_output_msg(parent_msg_id, cell, cell_index) ) + + # Create stdin polling task if input handling is enabled + task_poll_stdin_msg = None + if self.on_cell_input_request is not None: + task_poll_stdin_msg = asyncio.ensure_future( + self._async_poll_stdin_msg(parent_msg_id, cell, cell_index) + ) + self.task_poll_for_reply = asyncio.ensure_future( self._async_poll_for_reply( parent_msg_id, cell, exec_timeout, task_poll_output_msg, task_poll_kernel_alive @@ -1006,6 +1050,8 @@ async def async_execute_cell( except asyncio.CancelledError: # can only be cancelled by task_poll_kernel_alive when the kernel is dead task_poll_output_msg.cancel() + if task_poll_stdin_msg is not None: + task_poll_stdin_msg.cancel() raise DeadKernelError("Kernel died") from None except Exception as e: # Best effort to cancel request if it hasn't been resolved @@ -1013,9 +1059,15 @@ async def async_execute_cell( # Check if the task_poll_output is doing the raising for us if not isinstance(e, CellControlSignal): task_poll_output_msg.cancel() + if task_poll_stdin_msg is not None: + task_poll_stdin_msg.cancel() finally: raise + # Cancel stdin task after successful execution + if task_poll_stdin_msg is not None: + task_poll_stdin_msg.cancel() + if execution_count: cell["execution_count"] = execution_count await run_hook( diff --git a/tests/files/InputRequest.ipynb b/tests/files/InputRequest.ipynb new file mode 100644 index 0000000..6c2ef7d --- /dev/null +++ b/tests/files/InputRequest.ipynb @@ -0,0 +1,21 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "name = input(\"What is your name? \")\n", + "print(f\"Hello, {name}!\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tests/test_client.py b/tests/test_client.py index 3bf06d8..722a731 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -57,6 +57,7 @@ "on_cell_complete", "on_cell_executed", "on_cell_error", + "on_cell_input_request", "on_notebook_start", "on_notebook_complete", "on_notebook_error", @@ -238,8 +239,10 @@ class NotebookClientWithParentID(NotebookClient): executor.kc = MagicMock( iopub_channel=MagicMock(get_msg=message_mock), shell_channel=MagicMock(get_msg=shell_channel_message_mock()), + stdin_channel=MagicMock(get_msg=AsyncMock(side_effect=Empty())), execute=MagicMock(return_value=parent_id), is_alive=MagicMock(return_value=make_future(True)), + input=MagicMock(), ) executor.parent_id = parent_id return func(self, executor, cell_mock, message_mock) @@ -901,6 +904,7 @@ def test_execution_hook(self): hooks["on_cell_complete"].assert_called_once() hooks["on_cell_executed"].assert_called_once() hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_called_once() hooks["on_notebook_complete"].assert_called_once() hooks["on_notebook_error"].assert_not_called() @@ -917,6 +921,7 @@ def test_error_execution_hook_error(self): hooks["on_cell_complete"].assert_called_once() hooks["on_cell_executed"].assert_called_once() hooks["on_cell_error"].assert_called_once() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_called_once() hooks["on_notebook_complete"].assert_called_once() hooks["on_notebook_error"].assert_not_called() @@ -933,6 +938,7 @@ def test_error_notebook_hook(self): hooks["on_cell_complete"].assert_called_once() hooks["on_cell_executed"].assert_not_called() hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_called_once() hooks["on_notebook_complete"].assert_called_once() hooks["on_notebook_error"].assert_called_once() @@ -948,6 +954,7 @@ def test_async_execution_hook(self): hooks["on_cell_complete"].assert_called_once() hooks["on_cell_executed"].assert_called_once() hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_called_once() hooks["on_notebook_complete"].assert_called_once() hooks["on_notebook_error"].assert_not_called() @@ -964,10 +971,37 @@ def test_error_async_execution_hook(self): hooks["on_cell_complete"].assert_called_once() hooks["on_cell_executed"].assert_called_once() hooks["on_cell_error"].assert_called_once() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_called_once() hooks["on_notebook_complete"].assert_called_once() hooks["on_notebook_error"].assert_not_called() + def test_input_request_hook(self): + """Test that on_cell_input_request hook is called when cell requests input""" + filename = os.path.join(current_dir, "files", "InputRequest.ipynb") + with open(filename) as f: + input_nb = nbformat.read(f, 4) + executor, hooks = get_executor_with_hooks(nb=input_nb) + + # Set up the input request hook to return a mock response + hooks["on_cell_input_request"].return_value = "Test User" + + executor.execute() + hooks["on_cell_start"].assert_called_once() + hooks["on_cell_execute"].assert_called_once() + hooks["on_cell_complete"].assert_called_once() + hooks["on_cell_executed"].assert_called_once() + hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_called_once() + hooks["on_notebook_start"].assert_called_once() + hooks["on_notebook_complete"].assert_called_once() + hooks["on_notebook_error"].assert_not_called() + + # Verify the callback was called with correct arguments + call_args = hooks["on_cell_input_request"].call_args + assert call_args[1]["cell"] == input_nb.cells[0] # cell argument + assert call_args[1]["cell_index"] == 0 # cell_index argument + class TestRunCell(NBClientTestsBase): """Contains test functions for NotebookClient.execute_cell""" @@ -1763,6 +1797,7 @@ def test_cell_hooks(self, executor, cell_mock, message_mock): cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_OK ) hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_not_called() hooks["on_notebook_complete"].assert_not_called() hooks["on_notebook_error"].assert_not_called() @@ -1793,6 +1828,7 @@ def test_error_cell_hooks(self, executor, cell_mock, message_mock): hooks["on_cell_error"].assert_called_once_with( cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_ERROR ) + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_not_called() hooks["on_notebook_complete"].assert_not_called() hooks["on_notebook_error"].assert_not_called() @@ -1829,6 +1865,7 @@ def test_async_cell_hooks(self, executor, cell_mock, message_mock): cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_OK ) hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_not_called() hooks["on_notebook_complete"].assert_not_called() hooks["on_notebook_error"].assert_not_called() @@ -1859,6 +1896,7 @@ def test_error_async_cell_hooks(self, executor, cell_mock, message_mock): hooks["on_cell_error"].assert_called_once_with( cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_ERROR ) + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_not_called() hooks["on_notebook_complete"].assert_not_called() hooks["on_notebook_error"].assert_not_called() From cba24204e7728e9143eb5b3f5008d142e69c6e5b Mon Sep 17 00:00:00 2001 From: Dan Grahn Date: Thu, 17 Jul 2025 08:29:54 -0400 Subject: [PATCH 2/6] Add input_message to callback --- nbclient/client.py | 4 ++-- tests/test_client.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/nbclient/client.py b/nbclient/client.py index 448fbb1..ae205e2 100644 --- a/nbclient/client.py +++ b/nbclient/client.py @@ -331,7 +331,7 @@ class NotebookClient(LoggingConfigurable): help=dedent( """ A callable which executes when a cell requests input. - Called with kwargs ``cell`` and ``cell_index``. + Called with kwargs ``cell``, ``cell_index``, and ``input_request``. """ ), ) @@ -786,7 +786,7 @@ async def _async_poll_stdin_msg( if msg["parent_header"].get("msg_id") == parent_msg_id: if msg["header"]["msg_type"] == "input_request": response = await ensure_async( - self.on_cell_input_request(cell=cell, cell_index=cell_index) + self.on_cell_input_request(cell=cell, cell_index=cell_index, input_request=msg) ) self.kc.input(response) except Empty: diff --git a/tests/test_client.py b/tests/test_client.py index 722a731..f155cd6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1001,6 +1001,7 @@ def test_input_request_hook(self): call_args = hooks["on_cell_input_request"].call_args assert call_args[1]["cell"] == input_nb.cells[0] # cell argument assert call_args[1]["cell_index"] == 0 # cell_index argument + assert "input_request" in call_args[1] # input_request argument class TestRunCell(NBClientTestsBase): From a208754639658922a89bcda89b0031733706ff6d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Jul 2025 12:47:08 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nbclient/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nbclient/client.py b/nbclient/client.py index ae205e2..ea362be 100644 --- a/nbclient/client.py +++ b/nbclient/client.py @@ -786,7 +786,9 @@ async def _async_poll_stdin_msg( if msg["parent_header"].get("msg_id") == parent_msg_id: if msg["header"]["msg_type"] == "input_request": response = await ensure_async( - self.on_cell_input_request(cell=cell, cell_index=cell_index, input_request=msg) + self.on_cell_input_request( + cell=cell, cell_index=cell_index, input_request=msg + ) ) self.kc.input(response) except Empty: From c17688db72f4095db138d475cae6a53d01733d32 Mon Sep 17 00:00:00 2001 From: Dan Grahn Date: Fri, 18 Jul 2025 07:41:07 -0400 Subject: [PATCH 4/6] Add timeout, remove continue --- nbclient/client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nbclient/client.py b/nbclient/client.py index ea362be..30f1120 100644 --- a/nbclient/client.py +++ b/nbclient/client.py @@ -782,7 +782,7 @@ async def _async_poll_stdin_msg( while True: try: - msg = await ensure_async(self.kc.stdin_channel.get_msg(timeout=None)) + msg = await ensure_async(self.kc.stdin_channel.get_msg(timeout=self.iopub_timeout)) if msg["parent_header"].get("msg_id") == parent_msg_id: if msg["header"]["msg_type"] == "input_request": response = await ensure_async( @@ -794,7 +794,6 @@ async def _async_poll_stdin_msg( except Empty: # Yield control to allow cancellation to be processed await asyncio.sleep(0.01) - continue async def _async_poll_for_reply( self, From c3d2179cea67d8b7d648bed39388c78af5299a69 Mon Sep 17 00:00:00 2001 From: Dan Grahn Date: Fri, 25 Jul 2025 10:16:09 -0400 Subject: [PATCH 5/6] Convert ensure_future to create_task --- nbclient/client.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nbclient/client.py b/nbclient/client.py index 30f1120..6cee2a0 100644 --- a/nbclient/client.py +++ b/nbclient/client.py @@ -643,7 +643,7 @@ async def async_setup_kernel(self, **kwargs: t.Any) -> t.AsyncGenerator[None, No def on_signal() -> None: """Handle signals.""" - self._async_cleanup_kernel_future = asyncio.ensure_future(self._async_cleanup_kernel()) + self._async_cleanup_kernel_future = asyncio.create_task(self._async_cleanup_kernel()) atexit.unregister(self._cleanup_kernel) loop = asyncio.get_event_loop() @@ -1029,19 +1029,19 @@ async def async_execute_cell( cell.outputs = [] self.clear_before_next_output = False - task_poll_kernel_alive = asyncio.ensure_future(self._async_poll_kernel_alive()) - task_poll_output_msg = asyncio.ensure_future( + task_poll_kernel_alive = asyncio.create_task(self._async_poll_kernel_alive()) + task_poll_output_msg = asyncio.create_task( self._async_poll_output_msg(parent_msg_id, cell, cell_index) ) # Create stdin polling task if input handling is enabled task_poll_stdin_msg = None if self.on_cell_input_request is not None: - task_poll_stdin_msg = asyncio.ensure_future( + task_poll_stdin_msg = asyncio.create_task( self._async_poll_stdin_msg(parent_msg_id, cell, cell_index) ) - self.task_poll_for_reply = asyncio.ensure_future( + self.task_poll_for_reply = asyncio.create_task( self._async_poll_for_reply( parent_msg_id, cell, exec_timeout, task_poll_output_msg, task_poll_kernel_alive ) From f4b22380a3fd13ee99c76238e9e62ea48b90cf61 Mon Sep 17 00:00:00 2001 From: Dan Grahn Date: Fri, 25 Jul 2025 10:38:39 -0400 Subject: [PATCH 6/6] Undo create_task conversion, causes test errors. --- nbclient/client.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nbclient/client.py b/nbclient/client.py index 6cee2a0..30f1120 100644 --- a/nbclient/client.py +++ b/nbclient/client.py @@ -643,7 +643,7 @@ async def async_setup_kernel(self, **kwargs: t.Any) -> t.AsyncGenerator[None, No def on_signal() -> None: """Handle signals.""" - self._async_cleanup_kernel_future = asyncio.create_task(self._async_cleanup_kernel()) + self._async_cleanup_kernel_future = asyncio.ensure_future(self._async_cleanup_kernel()) atexit.unregister(self._cleanup_kernel) loop = asyncio.get_event_loop() @@ -1029,19 +1029,19 @@ async def async_execute_cell( cell.outputs = [] self.clear_before_next_output = False - task_poll_kernel_alive = asyncio.create_task(self._async_poll_kernel_alive()) - task_poll_output_msg = asyncio.create_task( + task_poll_kernel_alive = asyncio.ensure_future(self._async_poll_kernel_alive()) + task_poll_output_msg = asyncio.ensure_future( self._async_poll_output_msg(parent_msg_id, cell, cell_index) ) # Create stdin polling task if input handling is enabled task_poll_stdin_msg = None if self.on_cell_input_request is not None: - task_poll_stdin_msg = asyncio.create_task( + task_poll_stdin_msg = asyncio.ensure_future( self._async_poll_stdin_msg(parent_msg_id, cell, cell_index) ) - self.task_poll_for_reply = asyncio.create_task( + self.task_poll_for_reply = asyncio.ensure_future( self._async_poll_for_reply( parent_msg_id, cell, exec_timeout, task_poll_output_msg, task_poll_kernel_alive )