Skip to content

Commit 924d276

Browse files
author
chengcong1
committed
support nbmodel
remove unused import fix type check
1 parent 04dd3e7 commit 924d276

File tree

2 files changed

+288
-2
lines changed

2 files changed

+288
-2
lines changed

jupyter_server/gateway/managers.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,16 @@
66

77
import asyncio
88
import datetime
9+
import inspect
910
import json
1011
import os
12+
import time
13+
import typing as t
1114
from queue import Empty, Queue
1215
from threading import Thread
1316
from time import monotonic
17+
from turtle import st
18+
from types import CoroutineType, coroutine
1419
from typing import TYPE_CHECKING, Any, Optional, cast
1520

1621
import websocket
@@ -642,6 +647,8 @@ async def get_msg(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
642647

643648
def send(self, msg: dict[str, Any]) -> None:
644649
"""Send a message to the queue."""
650+
if "channel" not in msg:
651+
msg["channel"] = self.channel_name
645652
message = json.dumps(msg, default=ChannelQueue.serialize_datetime).replace("</", "<\\/")
646653
self.log.debug(
647654
"Sending message on channel: %s, msg_id: %s, msg_type: %s",
@@ -683,6 +690,9 @@ def is_alive(self) -> bool:
683690
"""Whether the queue is alive."""
684691
return self.channel_socket is not None
685692

693+
async def msg_ready(self) -> bool:
694+
return not self.empty()
695+
686696

687697
class HBChannelQueue(ChannelQueue):
688698
"""A queue for the heartbeat channel."""
@@ -877,5 +887,187 @@ def _route_responses(self):
877887

878888
self.log.debug("Response router thread exiting...")
879889

890+
async def _maybe_awaitable(self, func_result):
891+
"""Helper to handle potentially awaitable results"""
892+
if inspect.isawaitable(func_result):
893+
await func_result
894+
895+
async def _handle_iopub_stdin_messages(
896+
self,
897+
msg_id: str,
898+
output_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]],
899+
stdin_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]],
900+
timeout: t.Optional[float],
901+
allow_stdin: bool,
902+
start_time: float,
903+
) -> None:
904+
"""Handle IOPub messages until idle state"""
905+
while True:
906+
# Calculate remaining timeout
907+
if timeout is not None:
908+
elapsed = time.monotonic() - start_time
909+
remaining = max(0, timeout - elapsed)
910+
if remaining <= 0:
911+
raise TimeoutError("Timeout in IOPub handling")
912+
else:
913+
remaining = None
914+
if stdin_hook is not None and allow_stdin:
915+
await self._handle_stdin_messages(stdin_hook, allow_stdin)
916+
try:
917+
msg = await self.iopub_channel.get_msg(timeout=remaining)
918+
except Exception as e:
919+
self.log.warning(f"err ({e})")
920+
921+
if msg["parent_header"].get("msg_id") != msg_id:
922+
continue
923+
924+
if output_hook is not None:
925+
await self._maybe_awaitable(output_hook(msg))
926+
927+
if (
928+
msg["header"]["msg_type"] == "status"
929+
and msg["content"].get("execution_state") == "idle"
930+
):
931+
break
932+
933+
async def _handle_stdin_messages(
934+
self,
935+
stdin_hook: t.Callable[[dict[str, t.Any]], t.Any],
936+
allow_stdin: bool,
937+
) -> None:
938+
"""Handle stdin messages until iopub is idle"""
939+
if not allow_stdin:
940+
return
941+
try:
942+
msg = await self.stdin_channel.get_msg(timeout=0.01)
943+
self.log.info(f"stdin msg: {msg},{type(msg)}")
944+
await self._maybe_awaitable(stdin_hook(msg))
945+
except (Empty, TimeoutError):
946+
pass
947+
except Exception:
948+
self.log.warning("Error handling stdin message", exc_info=True)
949+
950+
async def _wait_for_execution_reply(
951+
self, msg_id: str, timeout: t.Optional[float], start_time: float
952+
) -> dict[str, t.Any]:
953+
"""Wait for execution reply from shell or control channel"""
954+
# Calculate remaining timeout
955+
if timeout is not None:
956+
elapsed = time.monotonic() - start_time
957+
remaining_timeout = max(0, timeout - elapsed)
958+
if remaining_timeout <= 0:
959+
raise TimeoutError("Timeout waiting for reply")
960+
else:
961+
remaining_timeout = None
962+
963+
deadline = time.monotonic() + remaining_timeout if remaining_timeout else None
964+
965+
while True:
966+
if deadline:
967+
remaining = max(0, deadline - time.monotonic())
968+
if remaining <= 0:
969+
raise TimeoutError("Timeout waiting for reply")
970+
else:
971+
remaining = None
972+
973+
# Listen to both shell and control channels
974+
reply_task = asyncio.create_task(self.shell_channel.get_msg(timeout=remaining))
975+
control_task = asyncio.create_task(self.control_channel.get_msg(timeout=remaining))
976+
977+
try:
978+
done, pending = await asyncio.wait(
979+
[reply_task, control_task],
980+
timeout=remaining,
981+
return_when=asyncio.FIRST_COMPLETED,
982+
)
983+
984+
# Cancel pending tasks
985+
for task in pending:
986+
task.cancel()
987+
try:
988+
await task
989+
except asyncio.CancelledError:
990+
pass
991+
992+
if not done:
993+
raise TimeoutError("Timeout waiting for reply")
994+
995+
for task in done:
996+
try:
997+
msg: dict[str, t.Any] = task.result()
998+
if msg["parent_header"].get("msg_id") == msg_id:
999+
return msg
1000+
except Exception:
1001+
continue
1002+
1003+
except asyncio.TimeoutError as err:
1004+
reply_task.cancel()
1005+
control_task.cancel()
1006+
raise TimeoutError("Timeout waiting for reply") from err
1007+
1008+
async def execute_interactive(
1009+
self,
1010+
code: str,
1011+
silent: bool = False,
1012+
store_history: bool = True,
1013+
user_expressions: t.Optional[dict[str, t.Any]] = None,
1014+
allow_stdin: t.Optional[bool] = None,
1015+
stop_on_error: bool = True,
1016+
timeout: t.Optional[float] = None,
1017+
output_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]] = None,
1018+
stdin_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]] = None,
1019+
) -> dict[str, t.Any]: # type: ignore[override] # Reason: base class sets `execute_interactive` via assignment, so mypy cannot infer override compatibility
1020+
"""Execute code in the kernel interactively via gateway"""
1021+
1022+
# Channel alive checks
1023+
if not self.iopub_channel.is_alive():
1024+
raise RuntimeError("IOPub channel must be running to receive output")
1025+
1026+
# Prepare defaults
1027+
if allow_stdin is None:
1028+
allow_stdin = self.allow_stdin
1029+
1030+
if output_hook is None:
1031+
output_hook = self._output_hook_default
1032+
if stdin_hook is None:
1033+
stdin_hook = self._stdin_hook_default
1034+
1035+
# Execute the code
1036+
msg_id = self.execute(
1037+
code=code,
1038+
silent=silent,
1039+
store_history=store_history,
1040+
user_expressions=user_expressions,
1041+
allow_stdin=allow_stdin,
1042+
stop_on_error=stop_on_error,
1043+
)
1044+
1045+
# Setup coordination
1046+
start_time = time.monotonic()
1047+
1048+
try:
1049+
# Handle IOPub messages until idle
1050+
iopub_task = asyncio.create_task(
1051+
self._handle_iopub_stdin_messages(
1052+
msg_id, output_hook, stdin_hook, timeout, allow_stdin, start_time
1053+
),
1054+
name="handle_iopub_stdin_messages",
1055+
)
1056+
await iopub_task
1057+
# Get the execution reply
1058+
reply = await self._wait_for_execution_reply(msg_id, timeout, start_time)
1059+
return reply
1060+
1061+
except asyncio.CancelledError:
1062+
raise
1063+
except TimeoutError:
1064+
raise
1065+
except Exception as e:
1066+
self.log.error(
1067+
f"Error during interactive execution: {e}, msg_id: {msg_id}",
1068+
exc_info=True,
1069+
)
1070+
raise RuntimeError(f"Error in interactive execution: {e}") from e
1071+
8801072

8811073
KernelClientABC.register(GatewayKernelClient)

tests/test_gateway.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,16 @@
2525
from traitlets.config import Config
2626

2727
from jupyter_server.gateway.connections import GatewayWebSocketConnection
28-
from jupyter_server.gateway.gateway_client import GatewayTokenRenewerBase, NoOpTokenRenewer
29-
from jupyter_server.gateway.managers import ChannelQueue, GatewayClient, GatewayKernelManager
28+
from jupyter_server.gateway.gateway_client import (
29+
GatewayTokenRenewerBase,
30+
NoOpTokenRenewer,
31+
)
32+
from jupyter_server.gateway.managers import (
33+
ChannelQueue,
34+
GatewayClient,
35+
GatewayKernelClient,
36+
GatewayKernelManager,
37+
)
3038
from jupyter_server.services.kernels.websocket import KernelWebsocketHandler
3139

3240
from .utils import expected_http_error
@@ -902,3 +910,89 @@ async def delete_kernel(jp_fetch, kernel_id):
902910
r = await jp_fetch("api", "kernels", kernel_id, method="DELETE")
903911
assert r.code == 204
904912
assert r.reason == "No Content"
913+
914+
915+
@pytest.fixture
916+
def mock_channel_queue():
917+
queue = ChannelQueue("shell", MagicMock(), MagicMock())
918+
return queue
919+
920+
921+
@pytest.fixture
922+
def gateway_kernel_client(init_gateway, monkeypatch):
923+
client = GatewayKernelClient("fake-kernel-id")
924+
client._channel_queues = {
925+
"shell": ChannelQueue("shell", MagicMock(), MagicMock()),
926+
"iopub": ChannelQueue("iopub", MagicMock(), MagicMock()),
927+
"stdin": ChannelQueue("stdin", MagicMock(), MagicMock()),
928+
"hb": ChannelQueue("hb", MagicMock(), MagicMock()),
929+
"control": ChannelQueue("control", MagicMock(), MagicMock()),
930+
}
931+
client._shell_channel = client._channel_queues["shell"]
932+
client._iopub_channel = client._channel_queues["iopub"]
933+
client._stdin_channel = client._channel_queues["stdin"]
934+
client._hb_channel = client._channel_queues["hb"]
935+
client._control_channel = client._channel_queues["control"]
936+
return client
937+
938+
939+
def fake_create_connection(*args, **kwargs):
940+
return MagicMock()
941+
942+
943+
async def test_gateway_kernel_client_start_and_stop_channels(gateway_kernel_client, monkeypatch):
944+
monkeypatch.setattr("websocket.create_connection", fake_create_connection)
945+
monkeypatch.setattr(gateway_kernel_client, "channel_socket", MagicMock())
946+
monkeypatch.setattr(gateway_kernel_client, "response_router", MagicMock())
947+
await gateway_kernel_client.start_channels()
948+
gateway_kernel_client.stop_channels()
949+
assert gateway_kernel_client._channels_stopped
950+
951+
952+
# @pytest.mark.asyncio
953+
async def test_gateway_kernel_client_execute_interactive(gateway_kernel_client, monkeypatch):
954+
gateway_kernel_client.execute = MagicMock(return_value="msg-123")
955+
956+
async def fake_shell_get_msg(timeout=None):
957+
return {"parent_header": {"msg_id": "msg-123"}, "msg_type": "execute_reply"}
958+
959+
gateway_kernel_client.shell_channel.get_msg = fake_shell_get_msg
960+
961+
async def fake_iopub_get_msg(timeout=None):
962+
await asyncio.sleep(0.01)
963+
return {
964+
"parent_header": {"msg_id": "msg-123"},
965+
"msg_type": "status",
966+
"header": {"msg_type": "status"},
967+
"content": {"execution_state": "idle"},
968+
}
969+
970+
gateway_kernel_client.iopub_channel.get_msg = fake_iopub_get_msg
971+
972+
async def fake_stdin_get_msg(timeout=None):
973+
await asyncio.sleep(0.01)
974+
return {"parent_header": {"msg_id": "msg-123"}, "msg_type": "input_request"}
975+
976+
gateway_kernel_client.stdin_channel.get_msg = fake_stdin_get_msg
977+
output_msgs = []
978+
979+
async def output_hook(msg):
980+
output_msgs.append(msg)
981+
982+
stdin_msgs = []
983+
984+
async def stdin_hook(msg):
985+
stdin_msgs.append(msg)
986+
987+
reply = await gateway_kernel_client.execute_interactive(
988+
"print(1)", output_hook=output_hook, stdin_hook=stdin_hook
989+
)
990+
assert reply["msg_type"] == "execute_reply"
991+
992+
993+
async def test_gateway_channel_queue_get_msg_with_response_router_finished(
994+
mock_channel_queue,
995+
):
996+
mock_channel_queue.response_router_finished = True
997+
with pytest.raises(RuntimeError):
998+
await mock_channel_queue.get_msg(timeout=0.1)

0 commit comments

Comments
 (0)