|
| 1 | +################################################################################ |
| 2 | +# |
| 3 | +# Copyright 2022-2025 Vincent Dary |
| 4 | +# |
| 5 | +# This file is part of fiit. |
| 6 | +# |
| 7 | +# fiit is free software: you can redistribute it and/or modify it under the |
| 8 | +# terms of the GNU General Public License as published by the Free Software |
| 9 | +# Foundation, either version 3 of the License, or (at your option) any later |
| 10 | +# version. |
| 11 | +# |
| 12 | +# fiit is distributed in the hope that it will be useful, but WITHOUT ANY |
| 13 | +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR |
| 14 | +# A PARTICULAR PURPOSE. See the GNU General Public License for more details. |
| 15 | +# |
| 16 | +# You should have received a copy of the GNU General Public License along with |
| 17 | +# fiit. If not, see <https://www.gnu.org/licenses/>. |
| 18 | +# |
| 19 | +################################################################################ |
| 20 | +import uuid |
| 21 | +import dataclasses |
| 22 | +import tempfile |
| 23 | +import signal |
| 24 | +import sys |
| 25 | +import os |
| 26 | +from typing import Optional |
| 27 | +import asyncio |
| 28 | + |
| 29 | +import zmq |
| 30 | + |
| 31 | +from prompt_toolkit.application import get_app_or_none |
| 32 | + |
| 33 | +from jupyter_console.ptshell import ZMQTerminalInteractiveShell, ask_yes_no |
| 34 | +from jupyter_console.app import ZMQTerminalIPythonApp |
| 35 | +from jupyter_client.consoleapp import JupyterConsoleApp |
| 36 | + |
| 37 | +from fiit.plugins.backend import BACKEND_REQ_GET_BACKEND_DATA, BackendRequest |
| 38 | + |
| 39 | + |
| 40 | +class RemoteKernelSync(Exception): |
| 41 | + pass |
| 42 | + |
| 43 | + |
| 44 | +class SynchronizedZmqTerminal(ZMQTerminalInteractiveShell): |
| 45 | + """ |
| 46 | + Extend the ZMQTerminalInteractiveShell |
| 47 | + """ |
| 48 | + include_other_output = True |
| 49 | + |
| 50 | + ECHO_FILTER = ['%emu_start', '%es', '%step', '%s', '%cont', '%c'] |
| 51 | + |
| 52 | + def __init__(self, **kwargs): |
| 53 | + super().__init__(**kwargs) |
| 54 | + |
| 55 | + # cache |
| 56 | + self._iopub_msg_cache: Optional[dict] = None |
| 57 | + |
| 58 | + # Warning: asyncio.Event must be initialized in the event loop |
| 59 | + self._is_running_cell: Optional[asyncio.Event] = None |
| 60 | + self._interact_lock: Optional[asyncio.Event] = None |
| 61 | + self._is_interact_loop_stopped: Optional[asyncio.Event] = None |
| 62 | + |
| 63 | + self._full_echo = os.getenv('FULL_ECHO', False) |
| 64 | + self._other_is_running_cell_with_echo = False |
| 65 | + self._other_is_running_cell = False |
| 66 | + |
| 67 | + def init_kernel_info(self): |
| 68 | + """ Subclassed to print stdout/stderr stream if kernel is busy. """ |
| 69 | + self.client.hb_channel.unpause() |
| 70 | + msg_id = self.client.kernel_info() |
| 71 | + iopub_socket = self.client.iopub_channel.socket |
| 72 | + shell_socket = self.client.shell_channel.socket |
| 73 | + socket_poller = zmq.Poller() |
| 74 | + socket_poller.register(iopub_socket, zmq.POLLIN) |
| 75 | + socket_poller.register(shell_socket, zmq.POLLIN) |
| 76 | + |
| 77 | + while True: |
| 78 | + socks = dict(socket_poller.poll()) |
| 79 | + |
| 80 | + if socks.get(shell_socket) == zmq.POLLIN: |
| 81 | + reply = self.client.get_shell_msg() |
| 82 | + if reply['parent_header'].get('msg_id') == msg_id: |
| 83 | + self.kernel_info = reply['content'] |
| 84 | + return |
| 85 | + |
| 86 | + elif socks.get(iopub_socket) == zmq.POLLIN: |
| 87 | + msg = self.client.iopub_channel.get_msg() |
| 88 | + if msg['header']['msg_type'] == 'stream': |
| 89 | + if msg['content']['name'] == "stdout": |
| 90 | + print(msg['content']['text'], end='', flush=True) |
| 91 | + elif msg['content']['name'] == 'stderr': |
| 92 | + print(msg['content']['text'], end='', flush=True) |
| 93 | + |
| 94 | + def _init_events(self): |
| 95 | + self._is_running_cell = asyncio.Event() |
| 96 | + self._is_interact_loop_stopped = asyncio.Event() |
| 97 | + self._interact_lock = asyncio.Event() |
| 98 | + self._is_running_cell.clear() |
| 99 | + self._is_interact_loop_stopped.clear() |
| 100 | + self._interact_lock.clear() |
| 101 | + |
| 102 | + @property |
| 103 | + def interact_loop_is_locked(self) -> bool: |
| 104 | + return self._is_interact_loop_stopped.is_set() |
| 105 | + |
| 106 | + @property |
| 107 | + def cell_is_running(self) -> bool: |
| 108 | + return self._is_running_cell.is_set() |
| 109 | + |
| 110 | + async def lock_interact_loop(self): |
| 111 | + if not self.interact_loop_is_locked and not self.cell_is_running: |
| 112 | + app = get_app_or_none() |
| 113 | + if not app.is_done and app.is_running: |
| 114 | + self._interact_lock.clear() |
| 115 | + app.exit(exception=RemoteKernelSync('kernel sync')) |
| 116 | + await self._is_interact_loop_stopped.wait() |
| 117 | + |
| 118 | + def unlock_interact_loop(self): |
| 119 | + if self.interact_loop_is_locked: |
| 120 | + self._interact_lock.set() |
| 121 | + |
| 122 | + async def interact(self, loop=None, display_banner=None): |
| 123 | + """ Override to allow prompt freezing via `RemoteKernelSync`. """ |
| 124 | + while self.keep_running: |
| 125 | + print('\n', end='') |
| 126 | + |
| 127 | + try: |
| 128 | + code = await self.prompt_for_code() |
| 129 | + except EOFError: |
| 130 | + if (not self.confirm_exit |
| 131 | + or ask_yes_no('Do you really want to exit ([y]/n)?', |
| 132 | + 'y', 'n')): |
| 133 | + self.ask_exit() |
| 134 | + except RemoteKernelSync: |
| 135 | + # Can fix ghost side effects of asynchronous prompt not yet exited |
| 136 | + # await asyncio.sleep(0.1) |
| 137 | + self._is_interact_loop_stopped.set() |
| 138 | + await self._interact_lock.wait() |
| 139 | + self._is_interact_loop_stopped.clear() |
| 140 | + |
| 141 | + else: |
| 142 | + if code: |
| 143 | + self._is_running_cell.set() |
| 144 | + self.run_cell(code, store_history=True) |
| 145 | + self._is_running_cell.clear() |
| 146 | + |
| 147 | + async def handle_external_iopub(self, loop=None): |
| 148 | + """ |
| 149 | + Override to fix inefficient and slow manual polling in parent method, |
| 150 | + and allow post jupyter message render with asynchronous capability in |
| 151 | + same event loop (for exemple for asynchronous event sync). |
| 152 | + """ |
| 153 | + self._init_events() |
| 154 | + poller = zmq.asyncio.Poller() |
| 155 | + poller.register(self.client.iopub_channel.socket, zmq.POLLIN) |
| 156 | + |
| 157 | + while self.keep_running: |
| 158 | + events = dict(await poller.poll(0.5)) |
| 159 | + |
| 160 | + if self.client.iopub_channel.socket in events: |
| 161 | + self.handle_iopub() |
| 162 | + await self._post_jupyter_message_render(self._iopub_msg_cache) |
| 163 | + |
| 164 | + async def _post_jupyter_message_render(self, msg: dict) -> None: |
| 165 | + msg = self._iopub_msg_cache |
| 166 | + msg_type = msg['header']['msg_type'] |
| 167 | + |
| 168 | + if (msg_type == 'execute_input' |
| 169 | + and self._other_is_running_cell_with_echo): |
| 170 | + await self.lock_interact_loop() |
| 171 | + content = self._iopub_msg_cache['content'] |
| 172 | + ec = content.get('execution_count', |
| 173 | + self.execution_count - 1) |
| 174 | + |
| 175 | + if self._pending_clearoutput: |
| 176 | + print("\r", end="") |
| 177 | + sys.stdout.flush() |
| 178 | + sys.stdout.flush() |
| 179 | + self._pending_clearoutput = False |
| 180 | + |
| 181 | + sys.stdout.write(f'Remote In [{ec}]: {content["code"]}\n') |
| 182 | + sys.stdout.flush() |
| 183 | + |
| 184 | + elif not self._other_is_running_cell_with_echo: |
| 185 | + self.unlock_interact_loop() |
| 186 | + |
| 187 | + def _include_output(self, msg: dict) -> bool: |
| 188 | + self._set_terminal_states(msg) |
| 189 | + msg_type = msg['header']['msg_type'] |
| 190 | + |
| 191 | + if self._other_is_running_cell_with_echo and msg_type == 'execute_input': |
| 192 | + return False # input render from handle_iopub() is bugged for other |
| 193 | + elif self._other_is_running_cell and not self._other_is_running_cell_with_echo: |
| 194 | + return False # no render for other cell running without echo |
| 195 | + |
| 196 | + return super().include_output(msg) |
| 197 | + |
| 198 | + def _msg_cache_wrapper(self, msg: dict) -> bool: |
| 199 | + ret = self._include_output(msg) |
| 200 | + self._iopub_msg_cache = msg |
| 201 | + return ret |
| 202 | + |
| 203 | + def include_output(self, msg: dict) -> bool: |
| 204 | + """ |
| 205 | + `Include_output()` is the best place to capture iopub message since this |
| 206 | + method is called just after read message on the channel iopub channel |
| 207 | + in `handle_iopub()`. |
| 208 | + """ |
| 209 | + return self._msg_cache_wrapper(msg) |
| 210 | + |
| 211 | + def _set_terminal_states(self, msg: dict) -> None: |
| 212 | + """ |
| 213 | + Warning: |
| 214 | + This methods set the states only for this terminal layer before |
| 215 | + `handle_iopub()` set states, so `ZMQTerminalInteractiveShell` states |
| 216 | + are the past states (t-1), minus the `execution_count` counter which is |
| 217 | + synchronized before this method call. |
| 218 | + """ |
| 219 | + msg_type = msg['header']['msg_type'] |
| 220 | + from_here = self.from_here(msg) |
| 221 | + |
| 222 | + if (self.include_other_output |
| 223 | + and not from_here |
| 224 | + and self._execution_state == 'busy' |
| 225 | + and msg_type == 'execute_input' |
| 226 | + and (self._full_echo or msg['content']['code'] in self.ECHO_FILTER)): |
| 227 | + self._other_is_running_cell_with_echo = True |
| 228 | + elif (self.include_other_output |
| 229 | + and not from_here |
| 230 | + and self._execution_state == 'busy' |
| 231 | + and msg_type == 'status' |
| 232 | + and msg['content']['execution_state'] == 'idle' |
| 233 | + and (self._full_echo or self._other_is_running_cell_with_echo)): |
| 234 | + self._other_is_running_cell_with_echo = False |
| 235 | + |
| 236 | + if (self.include_other_output |
| 237 | + and not from_here |
| 238 | + and self._execution_state == 'busy' |
| 239 | + and msg_type == 'execute_input'): |
| 240 | + self._other_is_running_cell = True |
| 241 | + elif (self.include_other_output |
| 242 | + and not from_here |
| 243 | + and self._execution_state == 'busy' |
| 244 | + and msg_type == 'status' |
| 245 | + and msg['content']['execution_state'] == 'idle'): |
| 246 | + self._other_is_running_cell = False |
| 247 | + |
| 248 | + |
| 249 | +class SynchronizedTerminalApp(ZMQTerminalIPythonApp): |
| 250 | + classes = [SynchronizedZmqTerminal] + JupyterConsoleApp.classes |
| 251 | + |
| 252 | + def init_shell(self): |
| 253 | + JupyterConsoleApp.initialize(self) |
| 254 | + # relay sigint to kernel |
| 255 | + signal.signal(signal.SIGINT, self.handle_sigint) |
| 256 | + self.shell = SynchronizedZmqTerminal.instance( |
| 257 | + parent=self, |
| 258 | + manager=self.kernel_manager, |
| 259 | + client=self.kernel_client, |
| 260 | + confirm_exit=self.confirm_exit, |
| 261 | + ) |
| 262 | + self.shell.own_kernel = not self.existing |
| 263 | + |
| 264 | + |
| 265 | +fiit_console = SynchronizedTerminalApp.launch_instance |
| 266 | + |
| 267 | + |
| 268 | +def fiit_console_from_backend(backend_ip: str, backend_port: str) -> None: |
| 269 | + zmq_context = zmq.Context() |
| 270 | + sock = zmq_context.socket(zmq.REQ) |
| 271 | + sock.connect(f'tcp://{backend_ip}:{backend_port}') |
| 272 | + req = BackendRequest(method=BACKEND_REQ_GET_BACKEND_DATA, id=uuid.uuid1().hex) |
| 273 | + sock.send_json(dataclasses.asdict(req)) |
| 274 | + res = sock.recv_json() |
| 275 | + sock.close() |
| 276 | + |
| 277 | + if res.get('error') is not None: |
| 278 | + print(f'error: {res["error"]["message"]}') |
| 279 | + sys.exit(1) |
| 280 | + |
| 281 | + f = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) |
| 282 | + f.write(res['result']['jupyter_client_json_config']) |
| 283 | + f.close() |
| 284 | + print(f'[i] Jupyter console configuration file dropped to "{f.name}".') |
| 285 | + SynchronizedTerminalApp.launch_instance(['--existing', f.name]) |
0 commit comments