diff --git a/nbclient/client.py b/nbclient/client.py index 936bb4a..30f1120 100644 --- a/nbclient/client.py +++ b/nbclient/client.py @@ -325,6 +325,17 @@ class NotebookClient(LoggingConfigurable): ), ).tag(config=True) + on_cell_input_request = Callable( + default_value=None, + allow_none=True, + help=dedent( + """ + A callable which executes when a cell requests input. + Called with kwargs ``cell``, ``cell_index``, and ``input_request``. + """ + ), + ) + on_cell_start = Callable( default_value=None, allow_none=True, @@ -572,7 +583,7 @@ async def async_start_new_kernel_client(self) -> KernelClient: ) await self._async_cleanup_kernel() raise - self.kc.allow_stdin = False + self.kc.allow_stdin = self.on_cell_input_request is not None await run_hook(self.on_notebook_start, notebook=self.nb) return self.kc @@ -758,6 +769,32 @@ def _update_display_id(self, display_id: str, msg: dict[str, t.Any]) -> None: outputs[output_idx]["data"] = out["data"] outputs[output_idx]["metadata"] = out["metadata"] + async def _async_poll_stdin_msg( + self, parent_msg_id: str, cell: NotebookNode, cell_index: int + ) -> None: + """Poll for stdin messages (input requests) from the kernel. + + This method runs in parallel with _async_poll_output_msg and handles + input requests by calling the on_cell_input_request callback and + sending the response back to the kernel. + """ + assert self.kc is not None + + while True: + try: + msg = await ensure_async(self.kc.stdin_channel.get_msg(timeout=self.iopub_timeout)) + if msg["parent_header"].get("msg_id") == parent_msg_id: + if msg["header"]["msg_type"] == "input_request": + response = await ensure_async( + self.on_cell_input_request( + cell=cell, cell_index=cell_index, input_request=msg + ) + ) + self.kc.input(response) + except Empty: + # Yield control to allow cancellation to be processed + await asyncio.sleep(0.01) + async def _async_poll_for_reply( self, msg_id: str, @@ -996,6 +1033,14 @@ async def async_execute_cell( task_poll_output_msg = asyncio.ensure_future( self._async_poll_output_msg(parent_msg_id, cell, cell_index) ) + + # Create stdin polling task if input handling is enabled + task_poll_stdin_msg = None + if self.on_cell_input_request is not None: + task_poll_stdin_msg = asyncio.ensure_future( + self._async_poll_stdin_msg(parent_msg_id, cell, cell_index) + ) + self.task_poll_for_reply = asyncio.ensure_future( self._async_poll_for_reply( parent_msg_id, cell, exec_timeout, task_poll_output_msg, task_poll_kernel_alive @@ -1006,6 +1051,8 @@ async def async_execute_cell( except asyncio.CancelledError: # can only be cancelled by task_poll_kernel_alive when the kernel is dead task_poll_output_msg.cancel() + if task_poll_stdin_msg is not None: + task_poll_stdin_msg.cancel() raise DeadKernelError("Kernel died") from None except Exception as e: # Best effort to cancel request if it hasn't been resolved @@ -1013,9 +1060,15 @@ async def async_execute_cell( # Check if the task_poll_output is doing the raising for us if not isinstance(e, CellControlSignal): task_poll_output_msg.cancel() + if task_poll_stdin_msg is not None: + task_poll_stdin_msg.cancel() finally: raise + # Cancel stdin task after successful execution + if task_poll_stdin_msg is not None: + task_poll_stdin_msg.cancel() + if execution_count: cell["execution_count"] = execution_count await run_hook( diff --git a/tests/files/InputRequest.ipynb b/tests/files/InputRequest.ipynb new file mode 100644 index 0000000..6c2ef7d --- /dev/null +++ b/tests/files/InputRequest.ipynb @@ -0,0 +1,21 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "name = input(\"What is your name? \")\n", + "print(f\"Hello, {name}!\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tests/test_client.py b/tests/test_client.py index 3bf06d8..f155cd6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -57,6 +57,7 @@ "on_cell_complete", "on_cell_executed", "on_cell_error", + "on_cell_input_request", "on_notebook_start", "on_notebook_complete", "on_notebook_error", @@ -238,8 +239,10 @@ class NotebookClientWithParentID(NotebookClient): executor.kc = MagicMock( iopub_channel=MagicMock(get_msg=message_mock), shell_channel=MagicMock(get_msg=shell_channel_message_mock()), + stdin_channel=MagicMock(get_msg=AsyncMock(side_effect=Empty())), execute=MagicMock(return_value=parent_id), is_alive=MagicMock(return_value=make_future(True)), + input=MagicMock(), ) executor.parent_id = parent_id return func(self, executor, cell_mock, message_mock) @@ -901,6 +904,7 @@ def test_execution_hook(self): hooks["on_cell_complete"].assert_called_once() hooks["on_cell_executed"].assert_called_once() hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_called_once() hooks["on_notebook_complete"].assert_called_once() hooks["on_notebook_error"].assert_not_called() @@ -917,6 +921,7 @@ def test_error_execution_hook_error(self): hooks["on_cell_complete"].assert_called_once() hooks["on_cell_executed"].assert_called_once() hooks["on_cell_error"].assert_called_once() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_called_once() hooks["on_notebook_complete"].assert_called_once() hooks["on_notebook_error"].assert_not_called() @@ -933,6 +938,7 @@ def test_error_notebook_hook(self): hooks["on_cell_complete"].assert_called_once() hooks["on_cell_executed"].assert_not_called() hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_called_once() hooks["on_notebook_complete"].assert_called_once() hooks["on_notebook_error"].assert_called_once() @@ -948,6 +954,7 @@ def test_async_execution_hook(self): hooks["on_cell_complete"].assert_called_once() hooks["on_cell_executed"].assert_called_once() hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_called_once() hooks["on_notebook_complete"].assert_called_once() hooks["on_notebook_error"].assert_not_called() @@ -964,10 +971,38 @@ def test_error_async_execution_hook(self): hooks["on_cell_complete"].assert_called_once() hooks["on_cell_executed"].assert_called_once() hooks["on_cell_error"].assert_called_once() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_called_once() hooks["on_notebook_complete"].assert_called_once() hooks["on_notebook_error"].assert_not_called() + def test_input_request_hook(self): + """Test that on_cell_input_request hook is called when cell requests input""" + filename = os.path.join(current_dir, "files", "InputRequest.ipynb") + with open(filename) as f: + input_nb = nbformat.read(f, 4) + executor, hooks = get_executor_with_hooks(nb=input_nb) + + # Set up the input request hook to return a mock response + hooks["on_cell_input_request"].return_value = "Test User" + + executor.execute() + hooks["on_cell_start"].assert_called_once() + hooks["on_cell_execute"].assert_called_once() + hooks["on_cell_complete"].assert_called_once() + hooks["on_cell_executed"].assert_called_once() + hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_called_once() + hooks["on_notebook_start"].assert_called_once() + hooks["on_notebook_complete"].assert_called_once() + hooks["on_notebook_error"].assert_not_called() + + # Verify the callback was called with correct arguments + call_args = hooks["on_cell_input_request"].call_args + assert call_args[1]["cell"] == input_nb.cells[0] # cell argument + assert call_args[1]["cell_index"] == 0 # cell_index argument + assert "input_request" in call_args[1] # input_request argument + class TestRunCell(NBClientTestsBase): """Contains test functions for NotebookClient.execute_cell""" @@ -1763,6 +1798,7 @@ def test_cell_hooks(self, executor, cell_mock, message_mock): cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_OK ) hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_not_called() hooks["on_notebook_complete"].assert_not_called() hooks["on_notebook_error"].assert_not_called() @@ -1793,6 +1829,7 @@ def test_error_cell_hooks(self, executor, cell_mock, message_mock): hooks["on_cell_error"].assert_called_once_with( cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_ERROR ) + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_not_called() hooks["on_notebook_complete"].assert_not_called() hooks["on_notebook_error"].assert_not_called() @@ -1829,6 +1866,7 @@ def test_async_cell_hooks(self, executor, cell_mock, message_mock): cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_OK ) hooks["on_cell_error"].assert_not_called() + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_not_called() hooks["on_notebook_complete"].assert_not_called() hooks["on_notebook_error"].assert_not_called() @@ -1859,6 +1897,7 @@ def test_error_async_cell_hooks(self, executor, cell_mock, message_mock): hooks["on_cell_error"].assert_called_once_with( cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_ERROR ) + hooks["on_cell_input_request"].assert_not_called() hooks["on_notebook_start"].assert_not_called() hooks["on_notebook_complete"].assert_not_called() hooks["on_notebook_error"].assert_not_called()