Skip to content

Commit 1fd068f

Browse files
committed
Support stdin in server-side execution
1 parent bf1b0df commit 1fd068f

File tree

2 files changed

+88
-5
lines changed

2 files changed

+88
-5
lines changed

plugins/kernels/fps_kernels/kernel_driver/driver.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import os
33
import time
44
import uuid
5+
from functools import partial
56
from typing import Any, Dict, List, Optional, cast
67

7-
from pycrdt import Array, Map
8+
from pycrdt import Array, Map, Text
89

910
from jupyverse_api.yjs import Yjs
1011

@@ -46,6 +47,7 @@ def __init__(
4647
self.execute_requests: Dict[str, Dict[str, asyncio.Queue]] = {}
4748
self.comm_messages: asyncio.Queue = asyncio.Queue()
4849
self.tasks: List[asyncio.Task] = []
50+
self._background_tasks: set[asyncio.Task] = set()
4951

5052
async def restart(self, startup_timeout: float = float("inf")) -> None:
5153
for task in self.tasks:
@@ -80,13 +82,23 @@ async def connect(self, startup_timeout: float = float("inf")) -> None:
8082

8183
def connect_channels(self, connection_cfg: Optional[cfg_t] = None):
8284
connection_cfg = connection_cfg or self.connection_cfg
83-
self.shell_channel = connect_channel("shell", connection_cfg)
85+
self.shell_channel = connect_channel(
86+
"shell",
87+
connection_cfg,
88+
identity=self.session_id.encode(),
89+
)
8490
self.control_channel = connect_channel("control", connection_cfg)
8591
self.iopub_channel = connect_channel("iopub", connection_cfg)
92+
self.stdin_channel = connect_channel(
93+
"stdin",
94+
connection_cfg,
95+
identity=self.session_id.encode(),
96+
)
8697

8798
def listen_channels(self):
8899
self.tasks.append(asyncio.create_task(self.listen_iopub()))
89100
self.tasks.append(asyncio.create_task(self.listen_shell()))
101+
self.tasks.append(asyncio.create_task(self.listen_stdin()))
90102

91103
async def stop(self) -> None:
92104
self.kernel_process.kill()
@@ -111,6 +123,13 @@ async def listen_shell(self):
111123
if msg_id in self.execute_requests.keys():
112124
self.execute_requests[msg_id]["shell_msg"].put_nowait(msg)
113125

126+
async def listen_stdin(self):
127+
while True:
128+
msg = await receive_message(self.stdin_channel, change_str_to_date=True)
129+
msg_id = msg["parent_header"].get("msg_id")
130+
if msg_id in self.execute_requests.keys():
131+
self.execute_requests[msg_id]["stdin_msg"].put_nowait(msg)
132+
114133
async def execute(
115134
self,
116135
ycell: Map,
@@ -121,7 +140,7 @@ async def execute(
121140
if ycell["cell_type"] != "code":
122141
return
123142
ycell["execution_state"] = "busy"
124-
content = {"code": str(ycell["source"]), "silent": False}
143+
content = {"code": str(ycell["source"]), "silent": False, "allow_stdin": True}
125144
msg = create_message(
126145
"execute_request", content, session_id=self.session_id, msg_id=str(self.msg_cnt)
127146
)
@@ -134,6 +153,7 @@ async def execute(
134153
self.execute_requests[msg_id] = {
135154
"iopub_msg": asyncio.Queue(),
136155
"shell_msg": asyncio.Queue(),
156+
"stdin_msg": asyncio.Queue(),
137157
}
138158
if wait_for_executed:
139159
deadline = time.time() + timeout
@@ -165,21 +185,75 @@ async def execute(
165185
ycell["execution_state"] = "idle"
166186
del self.execute_requests[msg_id]
167187
else:
168-
self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell)))
188+
stdin_task = asyncio.create_task(self._handle_stdin(msg_id, ycell))
189+
self.tasks.append(stdin_task)
190+
self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell, stdin_task)))
169191

170-
async def _handle_iopub(self, msg_id: str, ycell: Map) -> None:
192+
async def _handle_iopub(self, msg_id: str, ycell: Map, stdin_task: asyncio.Task) -> None:
171193
while True:
172194
msg = await self.execute_requests[msg_id]["iopub_msg"].get()
173195
await self._handle_outputs(ycell["outputs"], msg)
174196
if (
175197
(msg["header"]["msg_type"] == "status"
176198
and msg["content"]["execution_state"] == "idle")
177199
):
200+
stdin_task.cancel()
178201
msg = await self.execute_requests[msg_id]["shell_msg"].get()
179202
with ycell.doc.transaction():
180203
ycell["execution_count"] = msg["content"]["execution_count"]
181204
ycell["execution_state"] = "idle"
182205

206+
async def _handle_stdin(self, msg_id: str, ycell: Map) -> None:
207+
while True:
208+
msg = await self.execute_requests[msg_id]["stdin_msg"].get()
209+
if msg["msg_type"] == "input_request":
210+
content = msg["content"]
211+
prompt = content["prompt"]
212+
password = content["password"]
213+
stdin_output = Map(
214+
{
215+
"output_type": "stdin",
216+
"submitted": False,
217+
"password": password,
218+
"prompt": prompt,
219+
"value": Text(),
220+
}
221+
)
222+
outputs: Array = cast(Array, ycell.get("outputs"))
223+
stdin_idx = len(outputs)
224+
outputs.append(stdin_output)
225+
stdin_output.observe(
226+
partial(self._handle_stdin_submission, outputs, stdin_idx, password, prompt)
227+
)
228+
229+
def _handle_stdin_submission(self, outputs, stdin_idx, password, prompt, event):
230+
if event.target["submitted"]:
231+
# send input reply to kernel
232+
value = str(event.target["value"])
233+
content = {"value": value}
234+
msg = create_message(
235+
"input_reply", content, session_id=self.session_id, msg_id=str(self.msg_cnt)
236+
)
237+
task0 = asyncio.create_task(
238+
send_message(msg, self.stdin_channel, self.key, change_date_to_str=True)
239+
)
240+
if password:
241+
value = "········"
242+
value = f"{prompt} {value}"
243+
task1 = asyncio.create_task(self._change_stdin_to_stream(outputs, stdin_idx, value))
244+
self._background_tasks.add(task0)
245+
self._background_tasks.add(task1)
246+
task0.add_done_callback(self._background_tasks.discard)
247+
task1.add_done_callback(self._background_tasks.discard)
248+
249+
async def _change_stdin_to_stream(self, outputs, stdin_idx, value):
250+
# replace stdin output with stream output
251+
outputs[stdin_idx] = {
252+
"output_type": "stream",
253+
"name": "stdin",
254+
"text": value + '\n',
255+
}
256+
183257
async def _handle_comms(self) -> None:
184258
if self.yjs is None or self.yjs.widgets is None: # type: ignore
185259
return

plugins/yjs/fps_yjs/ydocs/ynotebook.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ def get_cell(self, index: int) -> Dict[str, Any]:
4949
and not cell["attachments"]
5050
):
5151
del cell["attachments"]
52+
outputs = cell.get("outputs", [])
53+
del_outputs = []
54+
for idx, output in enumerate(outputs):
55+
if output["output_type"] == "stdin":
56+
del_outputs.append(idx)
57+
deleted = 0
58+
for idx in del_outputs:
59+
del outputs[idx - deleted]
60+
deleted += 1
5261
return cell
5362

5463
def append_cell(self, value: Dict[str, Any]) -> None:

0 commit comments

Comments
 (0)