|
6 | 6 |
|
7 | 7 | import asyncio |
8 | 8 | import datetime |
| 9 | +import inspect |
9 | 10 | import json |
10 | 11 | import os |
| 12 | +import time |
| 13 | +import typing as t |
| 14 | +from operator import is_ |
11 | 15 | from queue import Empty, Queue |
12 | 16 | from threading import Thread |
13 | 17 | from time import monotonic |
@@ -642,6 +646,8 @@ async def get_msg(self, *args: Any, **kwargs: Any) -> dict[str, Any]: |
642 | 646 |
|
643 | 647 | def send(self, msg: dict[str, Any]) -> None: |
644 | 648 | """Send a message to the queue.""" |
| 649 | + if "channel" not in msg: |
| 650 | + msg["channel"] = self.channel_name |
645 | 651 | message = json.dumps(msg, default=ChannelQueue.serialize_datetime).replace("</", "<\\/") |
646 | 652 | self.log.debug( |
647 | 653 | "Sending message on channel: %s, msg_id: %s, msg_type: %s", |
@@ -683,6 +689,9 @@ def is_alive(self) -> bool: |
683 | 689 | """Whether the queue is alive.""" |
684 | 690 | return self.channel_socket is not None |
685 | 691 |
|
| 692 | + async def msg_ready(self) -> bool: |
| 693 | + return not self.empty() |
| 694 | + |
686 | 695 |
|
687 | 696 | class HBChannelQueue(ChannelQueue): |
688 | 697 | """A queue for the heartbeat channel.""" |
@@ -877,5 +886,185 @@ def _route_responses(self): |
877 | 886 |
|
878 | 887 | self.log.debug("Response router thread exiting...") |
879 | 888 |
|
| 889 | + async def _maybe_awaitable(self, func_result): |
| 890 | + """Helper to handle potentially awaitable results""" |
| 891 | + if inspect.isawaitable(func_result): |
| 892 | + await func_result |
| 893 | + |
| 894 | + async def _handle_iopub_stdin_messages( |
| 895 | + self, |
| 896 | + msg_id: str, |
| 897 | + output_hook: t.Callable, |
| 898 | + stdin_hook: t.Callable, |
| 899 | + timeout: t.Optional[float], |
| 900 | + allow_stdin: bool, |
| 901 | + start_time: float, |
| 902 | + ) -> None: |
| 903 | + """Handle IOPub messages until idle state""" |
| 904 | + while True: |
| 905 | + # Calculate remaining timeout |
| 906 | + if timeout is not None: |
| 907 | + elapsed = time.monotonic() - start_time |
| 908 | + remaining = max(0, timeout - elapsed) |
| 909 | + if remaining <= 0: |
| 910 | + raise TimeoutError("Timeout in IOPub handling") |
| 911 | + else: |
| 912 | + remaining = None |
| 913 | + await self._handle_stdin_messages(stdin_hook, allow_stdin) |
| 914 | + try: |
| 915 | + msg = await self.iopub_channel.get_msg(timeout=remaining) |
| 916 | + except Exception as e: |
| 917 | + self.log.warning(f"err ({e})") |
| 918 | + |
| 919 | + if msg["parent_header"].get("msg_id") != msg_id: |
| 920 | + continue |
| 921 | + |
| 922 | + await self._maybe_awaitable(output_hook(msg)) |
| 923 | + |
| 924 | + if ( |
| 925 | + msg["header"]["msg_type"] == "status" |
| 926 | + and msg["content"].get("execution_state") == "idle" |
| 927 | + ): |
| 928 | + break |
| 929 | + |
| 930 | + async def _handle_stdin_messages( |
| 931 | + self, |
| 932 | + stdin_hook: t.Callable, |
| 933 | + allow_stdin: bool, |
| 934 | + ) -> None: |
| 935 | + """Handle stdin messages until iopub is idle""" |
| 936 | + if not allow_stdin: |
| 937 | + return |
| 938 | + try: |
| 939 | + msg = await self.stdin_channel.get_msg(timeout=0.01) |
| 940 | + self.log.info(f"stdin msg: {msg},{type(msg)}") |
| 941 | + await self._maybe_awaitable(stdin_hook(msg)) |
| 942 | + except (Empty, TimeoutError): |
| 943 | + pass |
| 944 | + except Exception: |
| 945 | + self.log.warning("Error handling stdin message", exc_info=True) |
| 946 | + |
| 947 | + async def _wait_for_execution_reply( |
| 948 | + self, msg_id: str, timeout: t.Optional[float], start_time: float |
| 949 | + ) -> dict[str, t.Any]: |
| 950 | + """Wait for execution reply from shell or control channel""" |
| 951 | + # Calculate remaining timeout |
| 952 | + if timeout is not None: |
| 953 | + elapsed = time.monotonic() - start_time |
| 954 | + remaining_timeout = max(0, timeout - elapsed) |
| 955 | + if remaining_timeout <= 0: |
| 956 | + raise TimeoutError("Timeout waiting for reply") |
| 957 | + else: |
| 958 | + remaining_timeout = None |
| 959 | + |
| 960 | + deadline = time.monotonic() + remaining_timeout if remaining_timeout else None |
| 961 | + |
| 962 | + while True: |
| 963 | + if deadline: |
| 964 | + remaining = max(0, deadline - time.monotonic()) |
| 965 | + if remaining <= 0: |
| 966 | + raise TimeoutError("Timeout waiting for reply") |
| 967 | + else: |
| 968 | + remaining = None |
| 969 | + |
| 970 | + # Listen to both shell and control channels |
| 971 | + reply_task = asyncio.create_task(self.shell_channel.get_msg(timeout=remaining)) |
| 972 | + control_task = asyncio.create_task(self.control_channel.get_msg(timeout=remaining)) |
| 973 | + |
| 974 | + try: |
| 975 | + done, pending = await asyncio.wait( |
| 976 | + [reply_task, control_task], |
| 977 | + timeout=remaining, |
| 978 | + return_when=asyncio.FIRST_COMPLETED, |
| 979 | + ) |
| 980 | + |
| 981 | + # Cancel pending tasks |
| 982 | + for task in pending: |
| 983 | + task.cancel() |
| 984 | + try: |
| 985 | + await task |
| 986 | + except asyncio.CancelledError: |
| 987 | + pass |
| 988 | + |
| 989 | + if not done: |
| 990 | + raise TimeoutError("Timeout waiting for reply") |
| 991 | + |
| 992 | + for task in done: |
| 993 | + try: |
| 994 | + msg = task.result() |
| 995 | + if msg["parent_header"].get("msg_id") == msg_id: |
| 996 | + return msg |
| 997 | + except Exception: |
| 998 | + continue |
| 999 | + |
| 1000 | + except asyncio.TimeoutError as err: |
| 1001 | + reply_task.cancel() |
| 1002 | + control_task.cancel() |
| 1003 | + raise TimeoutError("Timeout waiting for reply") from err |
| 1004 | + |
| 1005 | + async def execute_interactive( |
| 1006 | + self, |
| 1007 | + code: str, |
| 1008 | + silent: bool = False, |
| 1009 | + store_history: bool = True, |
| 1010 | + user_expressions: t.Optional[dict[str, t.Any]] = None, |
| 1011 | + allow_stdin: t.Optional[bool] = None, |
| 1012 | + stop_on_error: bool = True, |
| 1013 | + timeout: t.Optional[float] = None, |
| 1014 | + output_hook: t.Optional[t.Callable[[dict], t.Any]] = None, |
| 1015 | + stdin_hook: t.Optional[t.Callable[[dict], t.Any]] = None, |
| 1016 | + ) -> dict[str, t.Any]: |
| 1017 | + """Execute code in the kernel interactively via gateway""" |
| 1018 | + |
| 1019 | + # Channel alive checks |
| 1020 | + if not self.iopub_channel.is_alive(): |
| 1021 | + raise RuntimeError("IOPub channel must be running to receive output") |
| 1022 | + |
| 1023 | + # Prepare defaults |
| 1024 | + if allow_stdin is None: |
| 1025 | + allow_stdin = self.allow_stdin |
| 1026 | + |
| 1027 | + if output_hook is None: |
| 1028 | + output_hook = self._output_hook_default |
| 1029 | + if stdin_hook is None: |
| 1030 | + stdin_hook = self._stdin_hook_default |
| 1031 | + |
| 1032 | + # Execute the code |
| 1033 | + msg_id = self.execute( |
| 1034 | + code=code, |
| 1035 | + silent=silent, |
| 1036 | + store_history=store_history, |
| 1037 | + user_expressions=user_expressions, |
| 1038 | + allow_stdin=allow_stdin, |
| 1039 | + stop_on_error=stop_on_error, |
| 1040 | + ) |
| 1041 | + |
| 1042 | + # Setup coordination |
| 1043 | + start_time = time.monotonic() |
| 1044 | + |
| 1045 | + try: |
| 1046 | + # Handle IOPub messages until idle |
| 1047 | + iopub_task = asyncio.create_task( |
| 1048 | + self._handle_iopub_stdin_messages( |
| 1049 | + msg_id, output_hook, stdin_hook, timeout, allow_stdin, start_time |
| 1050 | + ), |
| 1051 | + name="handle_iopub_stdin_messages", |
| 1052 | + ) |
| 1053 | + await iopub_task |
| 1054 | + # Get the execution reply |
| 1055 | + reply = await self._wait_for_execution_reply(msg_id, timeout, start_time) |
| 1056 | + return reply |
| 1057 | + |
| 1058 | + except asyncio.CancelledError: |
| 1059 | + raise |
| 1060 | + except TimeoutError: |
| 1061 | + raise |
| 1062 | + except Exception as e: |
| 1063 | + self.log.error( |
| 1064 | + f"Error during interactive execution: {e}, msg_id: {msg_id}", |
| 1065 | + exc_info=True, |
| 1066 | + ) |
| 1067 | + raise RuntimeError(f"Error in interactive execution: {e}") from e |
| 1068 | + |
880 | 1069 |
|
881 | 1070 | KernelClientABC.register(GatewayKernelClient) |
0 commit comments