Skip to content

Commit efc967d

Browse files
author
chengcong1
committed
support nbmodel
1 parent 742c241 commit efc967d

File tree

2 files changed

+285
-2
lines changed

2 files changed

+285
-2
lines changed

jupyter_server/gateway/managers.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66

77
import asyncio
88
import datetime
9+
import inspect
910
import json
1011
import os
12+
import time
13+
import typing as t
14+
from operator import is_
1115
from queue import Empty, Queue
1216
from threading import Thread
1317
from time import monotonic
@@ -642,6 +646,8 @@ async def get_msg(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
642646

643647
def send(self, msg: dict[str, Any]) -> None:
644648
"""Send a message to the queue."""
649+
if "channel" not in msg:
650+
msg["channel"] = self.channel_name
645651
message = json.dumps(msg, default=ChannelQueue.serialize_datetime).replace("</", "<\\/")
646652
self.log.debug(
647653
"Sending message on channel: %s, msg_id: %s, msg_type: %s",
@@ -683,6 +689,9 @@ def is_alive(self) -> bool:
683689
"""Whether the queue is alive."""
684690
return self.channel_socket is not None
685691

692+
async def msg_ready(self) -> bool:
693+
return not self.empty()
694+
686695

687696
class HBChannelQueue(ChannelQueue):
688697
"""A queue for the heartbeat channel."""
@@ -877,5 +886,185 @@ def _route_responses(self):
877886

878887
self.log.debug("Response router thread exiting...")
879888

889+
async def _maybe_awaitable(self, func_result):
890+
"""Helper to handle potentially awaitable results"""
891+
if inspect.isawaitable(func_result):
892+
await func_result
893+
894+
async def _handle_iopub_stdin_messages(
895+
self,
896+
msg_id: str,
897+
output_hook: t.Callable,
898+
stdin_hook: t.Callable,
899+
timeout: t.Optional[float],
900+
allow_stdin: bool,
901+
start_time: float,
902+
) -> None:
903+
"""Handle IOPub messages until idle state"""
904+
while True:
905+
# Calculate remaining timeout
906+
if timeout is not None:
907+
elapsed = time.monotonic() - start_time
908+
remaining = max(0, timeout - elapsed)
909+
if remaining <= 0:
910+
raise TimeoutError("Timeout in IOPub handling")
911+
else:
912+
remaining = None
913+
await self._handle_stdin_messages(stdin_hook, allow_stdin)
914+
try:
915+
msg = await self.iopub_channel.get_msg(timeout=remaining)
916+
except Exception as e:
917+
self.log.warning(f"err ({e})")
918+
919+
if msg["parent_header"].get("msg_id") != msg_id:
920+
continue
921+
922+
await self._maybe_awaitable(output_hook(msg))
923+
924+
if (
925+
msg["header"]["msg_type"] == "status"
926+
and msg["content"].get("execution_state") == "idle"
927+
):
928+
break
929+
930+
async def _handle_stdin_messages(
931+
self,
932+
stdin_hook: t.Callable,
933+
allow_stdin: bool,
934+
) -> None:
935+
"""Handle stdin messages until iopub is idle"""
936+
if not allow_stdin:
937+
return
938+
try:
939+
msg = await self.stdin_channel.get_msg(timeout=0.01)
940+
self.log.info(f"stdin msg: {msg},{type(msg)}")
941+
await self._maybe_awaitable(stdin_hook(msg))
942+
except (Empty, TimeoutError):
943+
pass
944+
except Exception:
945+
self.log.warning("Error handling stdin message", exc_info=True)
946+
947+
async def _wait_for_execution_reply(
948+
self, msg_id: str, timeout: t.Optional[float], start_time: float
949+
) -> dict[str, t.Any]:
950+
"""Wait for execution reply from shell or control channel"""
951+
# Calculate remaining timeout
952+
if timeout is not None:
953+
elapsed = time.monotonic() - start_time
954+
remaining_timeout = max(0, timeout - elapsed)
955+
if remaining_timeout <= 0:
956+
raise TimeoutError("Timeout waiting for reply")
957+
else:
958+
remaining_timeout = None
959+
960+
deadline = time.monotonic() + remaining_timeout if remaining_timeout else None
961+
962+
while True:
963+
if deadline:
964+
remaining = max(0, deadline - time.monotonic())
965+
if remaining <= 0:
966+
raise TimeoutError("Timeout waiting for reply")
967+
else:
968+
remaining = None
969+
970+
# Listen to both shell and control channels
971+
reply_task = asyncio.create_task(self.shell_channel.get_msg(timeout=remaining))
972+
control_task = asyncio.create_task(self.control_channel.get_msg(timeout=remaining))
973+
974+
try:
975+
done, pending = await asyncio.wait(
976+
[reply_task, control_task],
977+
timeout=remaining,
978+
return_when=asyncio.FIRST_COMPLETED,
979+
)
980+
981+
# Cancel pending tasks
982+
for task in pending:
983+
task.cancel()
984+
try:
985+
await task
986+
except asyncio.CancelledError:
987+
pass
988+
989+
if not done:
990+
raise TimeoutError("Timeout waiting for reply")
991+
992+
for task in done:
993+
try:
994+
msg = task.result()
995+
if msg["parent_header"].get("msg_id") == msg_id:
996+
return msg
997+
except Exception:
998+
continue
999+
1000+
except asyncio.TimeoutError as err:
1001+
reply_task.cancel()
1002+
control_task.cancel()
1003+
raise TimeoutError("Timeout waiting for reply") from err
1004+
1005+
async def execute_interactive(
1006+
self,
1007+
code: str,
1008+
silent: bool = False,
1009+
store_history: bool = True,
1010+
user_expressions: t.Optional[dict[str, t.Any]] = None,
1011+
allow_stdin: t.Optional[bool] = None,
1012+
stop_on_error: bool = True,
1013+
timeout: t.Optional[float] = None,
1014+
output_hook: t.Optional[t.Callable[[dict], t.Any]] = None,
1015+
stdin_hook: t.Optional[t.Callable[[dict], t.Any]] = None,
1016+
) -> dict[str, t.Any]:
1017+
"""Execute code in the kernel interactively via gateway"""
1018+
1019+
# Channel alive checks
1020+
if not self.iopub_channel.is_alive():
1021+
raise RuntimeError("IOPub channel must be running to receive output")
1022+
1023+
# Prepare defaults
1024+
if allow_stdin is None:
1025+
allow_stdin = self.allow_stdin
1026+
1027+
if output_hook is None:
1028+
output_hook = self._output_hook_default
1029+
if stdin_hook is None:
1030+
stdin_hook = self._stdin_hook_default
1031+
1032+
# Execute the code
1033+
msg_id = self.execute(
1034+
code=code,
1035+
silent=silent,
1036+
store_history=store_history,
1037+
user_expressions=user_expressions,
1038+
allow_stdin=allow_stdin,
1039+
stop_on_error=stop_on_error,
1040+
)
1041+
1042+
# Setup coordination
1043+
start_time = time.monotonic()
1044+
1045+
try:
1046+
# Handle IOPub messages until idle
1047+
iopub_task = asyncio.create_task(
1048+
self._handle_iopub_stdin_messages(
1049+
msg_id, output_hook, stdin_hook, timeout, allow_stdin, start_time
1050+
),
1051+
name="handle_iopub_stdin_messages",
1052+
)
1053+
await iopub_task
1054+
# Get the execution reply
1055+
reply = await self._wait_for_execution_reply(msg_id, timeout, start_time)
1056+
return reply
1057+
1058+
except asyncio.CancelledError:
1059+
raise
1060+
except TimeoutError:
1061+
raise
1062+
except Exception as e:
1063+
self.log.error(
1064+
f"Error during interactive execution: {e}, msg_id: {msg_id}",
1065+
exc_info=True,
1066+
)
1067+
raise RuntimeError(f"Error in interactive execution: {e}") from e
1068+
8801069

8811070
KernelClientABC.register(GatewayKernelClient)

tests/test_gateway.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,16 @@
2525
from traitlets.config import Config
2626

2727
from jupyter_server.gateway.connections import GatewayWebSocketConnection
28-
from jupyter_server.gateway.gateway_client import GatewayTokenRenewerBase, NoOpTokenRenewer
29-
from jupyter_server.gateway.managers import ChannelQueue, GatewayClient, GatewayKernelManager
28+
from jupyter_server.gateway.gateway_client import (
29+
GatewayTokenRenewerBase,
30+
NoOpTokenRenewer,
31+
)
32+
from jupyter_server.gateway.managers import (
33+
ChannelQueue,
34+
GatewayClient,
35+
GatewayKernelClient,
36+
GatewayKernelManager,
37+
)
3038
from jupyter_server.services.kernels.websocket import KernelWebsocketHandler
3139

3240
from .utils import expected_http_error
@@ -902,3 +910,89 @@ async def delete_kernel(jp_fetch, kernel_id):
902910
r = await jp_fetch("api", "kernels", kernel_id, method="DELETE")
903911
assert r.code == 204
904912
assert r.reason == "No Content"
913+
914+
915+
@pytest.fixture
916+
def mock_channel_queue():
917+
queue = ChannelQueue("shell", MagicMock(), MagicMock())
918+
return queue
919+
920+
921+
@pytest.fixture
922+
def gateway_kernel_client(init_gateway, monkeypatch):
923+
client = GatewayKernelClient("fake-kernel-id")
924+
client._channel_queues = {
925+
"shell": ChannelQueue("shell", MagicMock(), MagicMock()),
926+
"iopub": ChannelQueue("iopub", MagicMock(), MagicMock()),
927+
"stdin": ChannelQueue("stdin", MagicMock(), MagicMock()),
928+
"hb": ChannelQueue("hb", MagicMock(), MagicMock()),
929+
"control": ChannelQueue("control", MagicMock(), MagicMock()),
930+
}
931+
client._shell_channel = client._channel_queues["shell"]
932+
client._iopub_channel = client._channel_queues["iopub"]
933+
client._stdin_channel = client._channel_queues["stdin"]
934+
client._hb_channel = client._channel_queues["hb"]
935+
client._control_channel = client._channel_queues["control"]
936+
return client
937+
938+
939+
def fake_create_connection(*args, **kwargs):
940+
return MagicMock()
941+
942+
943+
async def test_gateway_kernel_client_start_and_stop_channels(gateway_kernel_client, monkeypatch):
944+
monkeypatch.setattr("websocket.create_connection", fake_create_connection)
945+
monkeypatch.setattr(gateway_kernel_client, "channel_socket", MagicMock())
946+
monkeypatch.setattr(gateway_kernel_client, "response_router", MagicMock())
947+
await gateway_kernel_client.start_channels()
948+
gateway_kernel_client.stop_channels()
949+
assert gateway_kernel_client._channels_stopped
950+
951+
952+
# @pytest.mark.asyncio
953+
async def test_gateway_kernel_client_execute_interactive(gateway_kernel_client, monkeypatch):
954+
gateway_kernel_client.execute = MagicMock(return_value="msg-123")
955+
956+
async def fake_shell_get_msg(timeout=None):
957+
return {"parent_header": {"msg_id": "msg-123"}, "msg_type": "execute_reply"}
958+
959+
gateway_kernel_client.shell_channel.get_msg = fake_shell_get_msg
960+
961+
async def fake_iopub_get_msg(timeout=None):
962+
await asyncio.sleep(0.01)
963+
return {
964+
"parent_header": {"msg_id": "msg-123"},
965+
"msg_type": "status",
966+
"header": {"msg_type": "status"},
967+
"content": {"execution_state": "idle"},
968+
}
969+
970+
gateway_kernel_client.iopub_channel.get_msg = fake_iopub_get_msg
971+
972+
async def fake_stdin_get_msg(timeout=None):
973+
await asyncio.sleep(0.01)
974+
return {"parent_header": {"msg_id": "msg-123"}, "msg_type": "input_request"}
975+
976+
gateway_kernel_client.stdin_channel.get_msg = fake_stdin_get_msg
977+
output_msgs = []
978+
979+
async def output_hook(msg):
980+
output_msgs.append(msg)
981+
982+
stdin_msgs = []
983+
984+
async def stdin_hook(msg):
985+
stdin_msgs.append(msg)
986+
987+
reply = await gateway_kernel_client.execute_interactive(
988+
"print(1)", output_hook=output_hook, stdin_hook=stdin_hook
989+
)
990+
assert reply["msg_type"] == "execute_reply"
991+
992+
993+
async def test_gateway_channel_queue_get_msg_with_response_router_finished(
994+
mock_channel_queue,
995+
):
996+
mock_channel_queue.response_router_finished = True
997+
with pytest.raises(RuntimeError):
998+
await mock_channel_queue.get_msg(timeout=0.1)

0 commit comments

Comments
 (0)