2
2
import os
3
3
import time
4
4
import uuid
5
+ from functools import partial
5
6
from typing import Any , Dict , List , Optional , cast
6
7
7
- from pycrdt import Array , Map
8
+ from pycrdt import Array , Map , Text
8
9
9
10
from jupyverse_api .yjs import Yjs
10
11
@@ -46,6 +47,7 @@ def __init__(
46
47
self .execute_requests : Dict [str , Dict [str , asyncio .Queue ]] = {}
47
48
self .comm_messages : asyncio .Queue = asyncio .Queue ()
48
49
self .tasks : List [asyncio .Task ] = []
50
+ self ._background_tasks : set [asyncio .Task ] = set ()
49
51
50
52
async def restart (self , startup_timeout : float = float ("inf" )) -> None :
51
53
for task in self .tasks :
@@ -80,13 +82,23 @@ async def connect(self, startup_timeout: float = float("inf")) -> None:
80
82
81
83
def connect_channels (self , connection_cfg : Optional [cfg_t ] = None ):
82
84
connection_cfg = connection_cfg or self .connection_cfg
83
- self .shell_channel = connect_channel ("shell" , connection_cfg )
85
+ self .shell_channel = connect_channel (
86
+ "shell" ,
87
+ connection_cfg ,
88
+ identity = self .session_id .encode (),
89
+ )
84
90
self .control_channel = connect_channel ("control" , connection_cfg )
85
91
self .iopub_channel = connect_channel ("iopub" , connection_cfg )
92
+ self .stdin_channel = connect_channel (
93
+ "stdin" ,
94
+ connection_cfg ,
95
+ identity = self .session_id .encode (),
96
+ )
86
97
87
98
def listen_channels (self ):
88
99
self .tasks .append (asyncio .create_task (self .listen_iopub ()))
89
100
self .tasks .append (asyncio .create_task (self .listen_shell ()))
101
+ self .tasks .append (asyncio .create_task (self .listen_stdin ()))
90
102
91
103
async def stop (self ) -> None :
92
104
self .kernel_process .kill ()
@@ -111,6 +123,13 @@ async def listen_shell(self):
111
123
if msg_id in self .execute_requests .keys ():
112
124
self .execute_requests [msg_id ]["shell_msg" ].put_nowait (msg )
113
125
126
+ async def listen_stdin (self ):
127
+ while True :
128
+ msg = await receive_message (self .stdin_channel , change_str_to_date = True )
129
+ msg_id = msg ["parent_header" ].get ("msg_id" )
130
+ if msg_id in self .execute_requests .keys ():
131
+ self .execute_requests [msg_id ]["stdin_msg" ].put_nowait (msg )
132
+
114
133
async def execute (
115
134
self ,
116
135
ycell : Map ,
@@ -121,7 +140,7 @@ async def execute(
121
140
if ycell ["cell_type" ] != "code" :
122
141
return
123
142
ycell ["execution_state" ] = "busy"
124
- content = {"code" : str (ycell ["source" ]), "silent" : False }
143
+ content = {"code" : str (ycell ["source" ]), "silent" : False , "allow_stdin" : True }
125
144
msg = create_message (
126
145
"execute_request" , content , session_id = self .session_id , msg_id = str (self .msg_cnt )
127
146
)
@@ -134,6 +153,7 @@ async def execute(
134
153
self .execute_requests [msg_id ] = {
135
154
"iopub_msg" : asyncio .Queue (),
136
155
"shell_msg" : asyncio .Queue (),
156
+ "stdin_msg" : asyncio .Queue (),
137
157
}
138
158
if wait_for_executed :
139
159
deadline = time .time () + timeout
@@ -165,21 +185,75 @@ async def execute(
165
185
ycell ["execution_state" ] = "idle"
166
186
del self .execute_requests [msg_id ]
167
187
else :
168
- self .tasks .append (asyncio .create_task (self ._handle_iopub (msg_id , ycell )))
188
+ stdin_task = asyncio .create_task (self ._handle_stdin (msg_id , ycell ))
189
+ self .tasks .append (stdin_task )
190
+ self .tasks .append (asyncio .create_task (self ._handle_iopub (msg_id , ycell , stdin_task )))
169
191
170
- async def _handle_iopub (self , msg_id : str , ycell : Map ) -> None :
192
+ async def _handle_iopub (self , msg_id : str , ycell : Map , stdin_task : asyncio . Task ) -> None :
171
193
while True :
172
194
msg = await self .execute_requests [msg_id ]["iopub_msg" ].get ()
173
195
await self ._handle_outputs (ycell ["outputs" ], msg )
174
196
if (
175
197
(msg ["header" ]["msg_type" ] == "status"
176
198
and msg ["content" ]["execution_state" ] == "idle" )
177
199
):
200
+ stdin_task .cancel ()
178
201
msg = await self .execute_requests [msg_id ]["shell_msg" ].get ()
179
202
with ycell .doc .transaction ():
180
203
ycell ["execution_count" ] = msg ["content" ]["execution_count" ]
181
204
ycell ["execution_state" ] = "idle"
182
205
206
+ async def _handle_stdin (self , msg_id : str , ycell : Map ) -> None :
207
+ while True :
208
+ msg = await self .execute_requests [msg_id ]["stdin_msg" ].get ()
209
+ if msg ["msg_type" ] == "input_request" :
210
+ content = msg ["content" ]
211
+ prompt = content ["prompt" ]
212
+ password = content ["password" ]
213
+ stdin_output = Map (
214
+ {
215
+ "output_type" : "stdin" ,
216
+ "submitted" : False ,
217
+ "password" : password ,
218
+ "prompt" : prompt ,
219
+ "value" : Text (),
220
+ }
221
+ )
222
+ outputs : Array = cast (Array , ycell .get ("outputs" ))
223
+ stdin_idx = len (outputs )
224
+ outputs .append (stdin_output )
225
+ stdin_output .observe (
226
+ partial (self ._handle_stdin_submission , outputs , stdin_idx , password , prompt )
227
+ )
228
+
229
+ def _handle_stdin_submission (self , outputs , stdin_idx , password , prompt , event ):
230
+ if event .target ["submitted" ]:
231
+ # send input reply to kernel
232
+ value = str (event .target ["value" ])
233
+ content = {"value" : value }
234
+ msg = create_message (
235
+ "input_reply" , content , session_id = self .session_id , msg_id = str (self .msg_cnt )
236
+ )
237
+ task0 = asyncio .create_task (
238
+ send_message (msg , self .stdin_channel , self .key , change_date_to_str = True )
239
+ )
240
+ if password :
241
+ value = "········"
242
+ value = f"{ prompt } { value } "
243
+ task1 = asyncio .create_task (self ._change_stdin_to_stream (outputs , stdin_idx , value ))
244
+ self ._background_tasks .add (task0 )
245
+ self ._background_tasks .add (task1 )
246
+ task0 .add_done_callback (self ._background_tasks .discard )
247
+ task1 .add_done_callback (self ._background_tasks .discard )
248
+
249
+ async def _change_stdin_to_stream (self , outputs , stdin_idx , value ):
250
+ # replace stdin output with stream output
251
+ outputs [stdin_idx ] = {
252
+ "output_type" : "stream" ,
253
+ "name" : "stdin" ,
254
+ "text" : value + '\n ' ,
255
+ }
256
+
183
257
async def _handle_comms (self ) -> None :
184
258
if self .yjs is None or self .yjs .widgets is None : # type: ignore
185
259
return
0 commit comments