Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion nbclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,17 @@ class NotebookClient(LoggingConfigurable):
),
).tag(config=True)

on_cell_input_request = Callable(
default_value=None,
allow_none=True,
help=dedent(
"""
A callable which executes when a cell requests input.
Called with kwargs ``cell``, ``cell_index``, and ``input_request``.
"""
),
)

on_cell_start = Callable(
default_value=None,
allow_none=True,
Expand Down Expand Up @@ -572,7 +583,7 @@ async def async_start_new_kernel_client(self) -> KernelClient:
)
await self._async_cleanup_kernel()
raise
self.kc.allow_stdin = False
self.kc.allow_stdin = self.on_cell_input_request is not None
await run_hook(self.on_notebook_start, notebook=self.nb)
return self.kc

Expand Down Expand Up @@ -758,6 +769,33 @@ def _update_display_id(self, display_id: str, msg: dict[str, t.Any]) -> None:
outputs[output_idx]["data"] = out["data"]
outputs[output_idx]["metadata"] = out["metadata"]

async def _async_poll_stdin_msg(
self, parent_msg_id: str, cell: NotebookNode, cell_index: int
) -> None:
"""Poll for stdin messages (input requests) from the kernel.

This method runs in parallel with _async_poll_output_msg and handles
input requests by calling the on_cell_input_request callback and
sending the response back to the kernel.
"""
assert self.kc is not None

while True:
try:
msg = await ensure_async(self.kc.stdin_channel.get_msg(timeout=None))
if msg["parent_header"].get("msg_id") == parent_msg_id:
if msg["header"]["msg_type"] == "input_request":
response = await ensure_async(
self.on_cell_input_request(
cell=cell, cell_index=cell_index, input_request=msg
)
)
self.kc.input(response)
except Empty:
# Yield control to allow cancellation to be processed
await asyncio.sleep(0.01)
continue

async def _async_poll_for_reply(
self,
msg_id: str,
Expand Down Expand Up @@ -996,6 +1034,14 @@ async def async_execute_cell(
task_poll_output_msg = asyncio.ensure_future(
self._async_poll_output_msg(parent_msg_id, cell, cell_index)
)

# Create stdin polling task if input handling is enabled
task_poll_stdin_msg = None
if self.on_cell_input_request is not None:
task_poll_stdin_msg = asyncio.ensure_future(
self._async_poll_stdin_msg(parent_msg_id, cell, cell_index)
)

self.task_poll_for_reply = asyncio.ensure_future(
self._async_poll_for_reply(
parent_msg_id, cell, exec_timeout, task_poll_output_msg, task_poll_kernel_alive
Expand All @@ -1006,16 +1052,24 @@ async def async_execute_cell(
except asyncio.CancelledError:
# can only be cancelled by task_poll_kernel_alive when the kernel is dead
task_poll_output_msg.cancel()
if task_poll_stdin_msg is not None:
task_poll_stdin_msg.cancel()
raise DeadKernelError("Kernel died") from None
except Exception as e:
# Best effort to cancel request if it hasn't been resolved
try:
# Check if the task_poll_output is doing the raising for us
if not isinstance(e, CellControlSignal):
task_poll_output_msg.cancel()
if task_poll_stdin_msg is not None:
task_poll_stdin_msg.cancel()
finally:
raise

# Cancel stdin task after successful execution
if task_poll_stdin_msg is not None:
task_poll_stdin_msg.cancel()

if execution_count:
cell["execution_count"] = execution_count
await run_hook(
Expand Down
21 changes: 21 additions & 0 deletions tests/files/InputRequest.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"name = input(\"What is your name? \")\n",
"print(f\"Hello, {name}!\")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
39 changes: 39 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"on_cell_complete",
"on_cell_executed",
"on_cell_error",
"on_cell_input_request",
"on_notebook_start",
"on_notebook_complete",
"on_notebook_error",
Expand Down Expand Up @@ -238,8 +239,10 @@ class NotebookClientWithParentID(NotebookClient):
executor.kc = MagicMock(
iopub_channel=MagicMock(get_msg=message_mock),
shell_channel=MagicMock(get_msg=shell_channel_message_mock()),
stdin_channel=MagicMock(get_msg=AsyncMock(side_effect=Empty())),
execute=MagicMock(return_value=parent_id),
is_alive=MagicMock(return_value=make_future(True)),
input=MagicMock(),
)
executor.parent_id = parent_id
return func(self, executor, cell_mock, message_mock)
Expand Down Expand Up @@ -901,6 +904,7 @@ def test_execution_hook(self):
hooks["on_cell_complete"].assert_called_once()
hooks["on_cell_executed"].assert_called_once()
hooks["on_cell_error"].assert_not_called()
hooks["on_cell_input_request"].assert_not_called()
hooks["on_notebook_start"].assert_called_once()
hooks["on_notebook_complete"].assert_called_once()
hooks["on_notebook_error"].assert_not_called()
Expand All @@ -917,6 +921,7 @@ def test_error_execution_hook_error(self):
hooks["on_cell_complete"].assert_called_once()
hooks["on_cell_executed"].assert_called_once()
hooks["on_cell_error"].assert_called_once()
hooks["on_cell_input_request"].assert_not_called()
hooks["on_notebook_start"].assert_called_once()
hooks["on_notebook_complete"].assert_called_once()
hooks["on_notebook_error"].assert_not_called()
Expand All @@ -933,6 +938,7 @@ def test_error_notebook_hook(self):
hooks["on_cell_complete"].assert_called_once()
hooks["on_cell_executed"].assert_not_called()
hooks["on_cell_error"].assert_not_called()
hooks["on_cell_input_request"].assert_not_called()
hooks["on_notebook_start"].assert_called_once()
hooks["on_notebook_complete"].assert_called_once()
hooks["on_notebook_error"].assert_called_once()
Expand All @@ -948,6 +954,7 @@ def test_async_execution_hook(self):
hooks["on_cell_complete"].assert_called_once()
hooks["on_cell_executed"].assert_called_once()
hooks["on_cell_error"].assert_not_called()
hooks["on_cell_input_request"].assert_not_called()
hooks["on_notebook_start"].assert_called_once()
hooks["on_notebook_complete"].assert_called_once()
hooks["on_notebook_error"].assert_not_called()
Expand All @@ -964,10 +971,38 @@ def test_error_async_execution_hook(self):
hooks["on_cell_complete"].assert_called_once()
hooks["on_cell_executed"].assert_called_once()
hooks["on_cell_error"].assert_called_once()
hooks["on_cell_input_request"].assert_not_called()
hooks["on_notebook_start"].assert_called_once()
hooks["on_notebook_complete"].assert_called_once()
hooks["on_notebook_error"].assert_not_called()

def test_input_request_hook(self):
"""Test that on_cell_input_request hook is called when cell requests input"""
filename = os.path.join(current_dir, "files", "InputRequest.ipynb")
with open(filename) as f:
input_nb = nbformat.read(f, 4)
executor, hooks = get_executor_with_hooks(nb=input_nb)

# Set up the input request hook to return a mock response
hooks["on_cell_input_request"].return_value = "Test User"

executor.execute()
hooks["on_cell_start"].assert_called_once()
hooks["on_cell_execute"].assert_called_once()
hooks["on_cell_complete"].assert_called_once()
hooks["on_cell_executed"].assert_called_once()
hooks["on_cell_error"].assert_not_called()
hooks["on_cell_input_request"].assert_called_once()
hooks["on_notebook_start"].assert_called_once()
hooks["on_notebook_complete"].assert_called_once()
hooks["on_notebook_error"].assert_not_called()

# Verify the callback was called with correct arguments
call_args = hooks["on_cell_input_request"].call_args
assert call_args[1]["cell"] == input_nb.cells[0] # cell argument
assert call_args[1]["cell_index"] == 0 # cell_index argument
assert "input_request" in call_args[1] # input_request argument


class TestRunCell(NBClientTestsBase):
"""Contains test functions for NotebookClient.execute_cell"""
Expand Down Expand Up @@ -1763,6 +1798,7 @@ def test_cell_hooks(self, executor, cell_mock, message_mock):
cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_OK
)
hooks["on_cell_error"].assert_not_called()
hooks["on_cell_input_request"].assert_not_called()
hooks["on_notebook_start"].assert_not_called()
hooks["on_notebook_complete"].assert_not_called()
hooks["on_notebook_error"].assert_not_called()
Expand Down Expand Up @@ -1793,6 +1829,7 @@ def test_error_cell_hooks(self, executor, cell_mock, message_mock):
hooks["on_cell_error"].assert_called_once_with(
cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_ERROR
)
hooks["on_cell_input_request"].assert_not_called()
hooks["on_notebook_start"].assert_not_called()
hooks["on_notebook_complete"].assert_not_called()
hooks["on_notebook_error"].assert_not_called()
Expand Down Expand Up @@ -1829,6 +1866,7 @@ def test_async_cell_hooks(self, executor, cell_mock, message_mock):
cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_OK
)
hooks["on_cell_error"].assert_not_called()
hooks["on_cell_input_request"].assert_not_called()
hooks["on_notebook_start"].assert_not_called()
hooks["on_notebook_complete"].assert_not_called()
hooks["on_notebook_error"].assert_not_called()
Expand Down Expand Up @@ -1859,6 +1897,7 @@ def test_error_async_cell_hooks(self, executor, cell_mock, message_mock):
hooks["on_cell_error"].assert_called_once_with(
cell=cell_mock, cell_index=0, execute_reply=EXECUTE_REPLY_ERROR
)
hooks["on_cell_input_request"].assert_not_called()
hooks["on_notebook_start"].assert_not_called()
hooks["on_notebook_complete"].assert_not_called()
hooks["on_notebook_error"].assert_not_called()
Expand Down
Loading