22import os
33import time
44import uuid
5+ from functools import partial
56from typing import Any , Dict , List , Optional , cast
67
7- from pycrdt import Array , Map
8+ from pycrdt import Array , Map , Text
89
910from jupyverse_api .yjs import Yjs
1011
@@ -46,6 +47,7 @@ def __init__(
4647 self .execute_requests : Dict [str , Dict [str , asyncio .Queue ]] = {}
4748 self .comm_messages : asyncio .Queue = asyncio .Queue ()
4849 self .tasks : List [asyncio .Task ] = []
50+ self ._background_tasks : set [asyncio .Task ] = set ()
4951
5052 async def restart (self , startup_timeout : float = float ("inf" )) -> None :
5153 for task in self .tasks :
@@ -80,13 +82,23 @@ async def connect(self, startup_timeout: float = float("inf")) -> None:
8082
8183 def connect_channels (self , connection_cfg : Optional [cfg_t ] = None ):
8284 connection_cfg = connection_cfg or self .connection_cfg
83- self .shell_channel = connect_channel ("shell" , connection_cfg )
85+ self .shell_channel = connect_channel (
86+ "shell" ,
87+ connection_cfg ,
88+ identity = self .session_id .encode (),
89+ )
8490 self .control_channel = connect_channel ("control" , connection_cfg )
8591 self .iopub_channel = connect_channel ("iopub" , connection_cfg )
92+ self .stdin_channel = connect_channel (
93+ "stdin" ,
94+ connection_cfg ,
95+ identity = self .session_id .encode (),
96+ )
8697
8798 def listen_channels (self ):
8899 self .tasks .append (asyncio .create_task (self .listen_iopub ()))
89100 self .tasks .append (asyncio .create_task (self .listen_shell ()))
101+ self .tasks .append (asyncio .create_task (self .listen_stdin ()))
90102
91103 async def stop (self ) -> None :
92104 self .kernel_process .kill ()
@@ -111,6 +123,13 @@ async def listen_shell(self):
111123 if msg_id in self .execute_requests .keys ():
112124 self .execute_requests [msg_id ]["shell_msg" ].put_nowait (msg )
113125
126+ async def listen_stdin (self ):
127+ while True :
128+ msg = await receive_message (self .stdin_channel , change_str_to_date = True )
129+ msg_id = msg ["parent_header" ].get ("msg_id" )
130+ if msg_id in self .execute_requests .keys ():
131+ self .execute_requests [msg_id ]["stdin_msg" ].put_nowait (msg )
132+
114133 async def execute (
115134 self ,
116135 ycell : Map ,
@@ -121,7 +140,7 @@ async def execute(
121140 if ycell ["cell_type" ] != "code" :
122141 return
123142 ycell ["execution_state" ] = "busy"
124- content = {"code" : str (ycell ["source" ]), "silent" : False }
143+ content = {"code" : str (ycell ["source" ]), "silent" : False , "allow_stdin" : True }
125144 msg = create_message (
126145 "execute_request" , content , session_id = self .session_id , msg_id = str (self .msg_cnt )
127146 )
@@ -134,6 +153,7 @@ async def execute(
134153 self .execute_requests [msg_id ] = {
135154 "iopub_msg" : asyncio .Queue (),
136155 "shell_msg" : asyncio .Queue (),
156+ "stdin_msg" : asyncio .Queue (),
137157 }
138158 if wait_for_executed :
139159 deadline = time .time () + timeout
@@ -165,21 +185,73 @@ async def execute(
165185 ycell ["execution_state" ] = "idle"
166186 del self .execute_requests [msg_id ]
167187 else :
168- self .tasks .append (asyncio .create_task (self ._handle_iopub (msg_id , ycell )))
188+ stdin_task = asyncio .create_task (self ._handle_stdin (msg_id , ycell ))
189+ self .tasks .append (stdin_task )
190+ self .tasks .append (asyncio .create_task (self ._handle_iopub (msg_id , ycell , stdin_task )))
169191
170- async def _handle_iopub (self , msg_id : str , ycell : Map ) -> None :
192+ async def _handle_iopub (self , msg_id : str , ycell : Map , stdin_task : asyncio . Task ) -> None :
171193 while True :
172194 msg = await self .execute_requests [msg_id ]["iopub_msg" ].get ()
173195 await self ._handle_outputs (ycell ["outputs" ], msg )
174196 if (
175197 (msg ["header" ]["msg_type" ] == "status"
176198 and msg ["content" ]["execution_state" ] == "idle" )
177199 ):
200+ stdin_task .cancel ()
178201 msg = await self .execute_requests [msg_id ]["shell_msg" ].get ()
179202 with ycell .doc .transaction ():
180203 ycell ["execution_count" ] = msg ["content" ]["execution_count" ]
181204 ycell ["execution_state" ] = "idle"
182205
206+ async def _handle_stdin (self , msg_id : str , ycell : Map ) -> None :
207+ while True :
208+ msg = await self .execute_requests [msg_id ]["stdin_msg" ].get ()
209+ if msg ["msg_type" ] == "input_request" :
210+ content = msg ["content" ]
211+ prompt = content ["prompt" ]
212+ password = content ["password" ]
213+ stdin_output = Map (
214+ {
215+ "output_type" : "stdin" ,
216+ "submitted" : False ,
217+ "password" : password ,
218+ "prompt" : prompt ,
219+ "value" : Text (),
220+ }
221+ )
222+ outputs = ycell .get ("outputs" )
223+ stdin_idx = len (outputs )
224+ outputs .append (stdin_output )
225+ stdin_output .observe (partial (self ._handle_stdin_submission , outputs , stdin_idx , password , prompt ))
226+
227+ def _handle_stdin_submission (self , outputs , stdin_idx , password , prompt , event ):
228+ if event .target ["submitted" ]:
229+ # send input reply to kernel
230+ value = str (event .target ["value" ])
231+ content = {"value" : value }
232+ msg = create_message (
233+ "input_reply" , content , session_id = self .session_id , msg_id = str (self .msg_cnt )
234+ )
235+ task0 = asyncio .create_task (
236+ send_message (msg , self .stdin_channel , self .key , change_date_to_str = True )
237+ )
238+ if password :
239+ value = "········"
240+ value = f"{ prompt } { value } "
241+ task1 = asyncio .create_task (self ._change_stdin_to_stream (outputs , stdin_idx , value ))
242+ self ._background_tasks .add (task0 )
243+ self ._background_tasks .add (task1 )
244+ task0 .add_done_callback (self ._background_tasks .discard )
245+ task1 .add_done_callback (self ._background_tasks .discard )
246+
247+ async def _change_stdin_to_stream (self , outputs , stdin_idx , value ):
248+ # replace stdin output with stream output
249+ outputs [stdin_idx ] = {
250+ "output_type" : "stream" ,
251+ "name" : "stdin" ,
252+ "text" : value + '\n ' ,
253+ }
254+
183255 async def _handle_comms (self ) -> None :
184256 if self .yjs is None or self .yjs .widgets is None : # type: ignore
185257 return
0 commit comments