Skip to content

Commit ffcfb02

Browse files
committed
FIXME attempt to support stdin -> for now server is locked :s
1 parent 96ff3f3 commit ffcfb02

File tree

2 files changed

+95
-40
lines changed

2 files changed

+95
-40
lines changed

jupyter_server_nbmodel/extension.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from jupyter_server.extension.application import ExtensionApp
44
from jupyter_server.services.kernels.handlers import _kernel_id_regex
55

6-
from .handlers import ExecuteHandler, ExecutionStack, RequestHandler
6+
from .handlers import ExecuteHandler, ExecutionStack, InputHandler, RequestHandler
77
from .log import get_logger
88

99
RTC_EXTENSIONAPP_NAME = "jupyter_server_ydoc"
@@ -36,6 +36,10 @@ def initialize_handlers(self):
3636
ExecuteHandler,
3737
{"ydoc_extension": rtc_extension, "execution_stack": self.__tasks},
3838
),
39+
(
40+
f"/api/kernels/{_kernel_id_regex}/input",
41+
InputHandler,
42+
),
3943
(
4044
f"/api/kernels/{_kernel_id_regex}/requests/{_request_id_regex}",
4145
RequestHandler,

jupyter_server_nbmodel/handlers.py

Lines changed: 90 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class ExecutionStack:
3737
"""
3838

3939
def __init__(self):
40+
self.__pending_inputs: dict[str, dict] = {}
4041
self.__tasks: dict[str, asyncio.Task] = {}
4142

4243
def __del__(self):
@@ -55,11 +56,12 @@ def cancel(self, uid: str) -> None:
5556

5657
self.__tasks[uid].cancel()
5758

58-
def get(self, uid: str) -> t.Any:
59+
def get(self, kernel_id: str, uid: str) -> t.Any:
5960
"""Get the request ``uid`` results or None.
6061
6162
Args:
62-
uid (str): Request index
63+
kernel_id : Kernel identifier
64+
uid : Request index
6365
6466
Returns:
6567
Any: None if the request is pending else its result
@@ -71,13 +73,16 @@ def get(self, uid: str) -> t.Any:
7173
if uid not in self.__tasks:
7274
raise ValueError(f"Request {uid} does not exists.")
7375

76+
if kernel_id in self.__pending_inputs:
77+
return self.__pending_inputs.pop(kernel_id)
78+
7479
if self.__tasks[uid].done():
7580
task = self.__tasks.pop(uid)
7681
return task.result()
7782
else:
7883
return None
7984

80-
def put(self, task: t.Awaitable, *args) -> str:
85+
def put(self, km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map) -> str:
8186
"""Add a asynchronous execution request.
8287
8388
Args:
@@ -87,36 +92,57 @@ def put(self, task: t.Awaitable, *args) -> str:
8792
Returns:
8893
Request identifier
8994
"""
90-
uid = uuid.uuid4()
95+
uid = str(uuid.uuid4())
9196

92-
async def execute_task(uid, f, *args) -> t.Any:
93-
try:
94-
get_logger().debug(f"Will execute request {uid}.")
95-
result = await f(*args)
96-
except asyncio.CancelledError:
97-
raise
98-
except Exception as e:
99-
exception_type, _, tb = sys.exc_info()
100-
result = {
101-
"type": exception_type.__qualname__,
102-
"error": str(e),
103-
"message": repr(e),
104-
"traceback": traceback.format_tb(tb),
105-
}
106-
get_logger().error("Error for request %s.", result)
107-
else:
108-
get_logger().debug(f"Has executed request {uid}.")
97+
self.__tasks[uid] = asyncio.create_task(
98+
execute_task(uid, km, snippet, ycell, partial(self._stdin_hook, km.kernel_id))
99+
)
100+
return uid
109101

110-
return result
102+
def _stdin_hook(self, kernel_id, msg) -> None:
103+
get_logger().info(f"Execution request {kernel_id} received a input request {msg!s}")
104+
if kernel_id in self.__pending_inputs:
105+
get_logger().error(f"Execution request {kernel_id} received a input request while waiting for an input.\n{msg}")
106+
107+
header = msg["header"].copy()
108+
header["date"] = header["date"].isoformat()
109+
self.__pending_inputs[kernel_id] = {"parent_header": header, "input_request": msg["content"]}
111110

112-
self.__tasks[uid] = asyncio.create_task(execute_task(uid, task, *args))
113-
return uid
111+
112+
async def execute_task(
113+
uid, km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map, stdin_hook
114+
) -> t.Any:
115+
try:
116+
get_logger().debug(f"Will execute request {uid}.")
117+
result = await _execute_snippet(uid, km, snippet, ycell, stdin_hook)
118+
except asyncio.CancelledError:
119+
raise
120+
except Exception as e:
121+
exception_type, _, tb = sys.exc_info()
122+
result = {
123+
"type": exception_type.__qualname__,
124+
"error": str(e),
125+
"message": repr(e),
126+
"traceback": traceback.format_tb(tb),
127+
}
128+
get_logger().error("Error for request %s.", result)
129+
else:
130+
get_logger().debug(f"Has executed request {uid}.")
131+
132+
return result
114133

115134

116-
async def execute_snippet(
117-
km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map
135+
async def _execute_snippet(
136+
uid: str,
137+
km: jupyter_client.client.KernelClient,
138+
snippet: str,
139+
ycell: y.Map,
140+
stdin_hook,
118141
) -> dict[str, t.Any]:
119142
client = km.client()
143+
client.session.session = uid
144+
# FIXME
145+
# client.session.username = username
120146

121147
if ycell is not None:
122148
# Reset cell
@@ -125,15 +151,14 @@ async def execute_snippet(
125151

126152
outputs = []
127153

128-
# FIXME set the username of client.session to server user
129154
# FIXME we don't check if the session is consistent (aka the kernel is linked to the document)
130155
# - should we?
131156
try:
132157
reply = await ensure_async(
133158
client.execute_interactive(
134159
snippet,
135160
output_hook=partial(_output_hook, ycell, outputs),
136-
stdin_hook=_stdin_hook if client.allow_stdin else None,
161+
stdin_hook=stdin_hook if client.allow_stdin else None,
137162
)
138163
)
139164

@@ -191,9 +216,6 @@ def _output_hook(ycell, outputs, msg) -> None:
191216
# FIXME
192217
...
193218

194-
def _stdin_hook(msg) -> None:
195-
get_logger().info("Code snippet execution is waiting for an input.")
196-
197219

198220
class ExecuteHandler(ExtensionHandlerMixin, APIHandler):
199221
"""Handle request for snippet execution."""
@@ -288,13 +310,38 @@ async def post(self, kernel_id: str) -> None:
288310
get_logger().error(msg, exc_info=e)
289311
raise tornado.web.HTTPError(status_code=HTTPStatus.NOT_FOUND, reason=msg) from e
290312

291-
uid = self._execution_stack.put(execute_snippet, km, snippet, ycell)
313+
uid = self._execution_stack.put(km, snippet, ycell)
292314

293315
self.set_status(HTTPStatus.ACCEPTED)
294316
self.set_header("Location", f"/api/kernels/{kernel_id}/requests/{uid}")
295317
self.finish("{}")
296318

297319

320+
class InputHandler(ExtensionHandlerMixin, APIHandler):
321+
"""Handle request for input reply."""
322+
323+
@tornado.web.authenticated
324+
async def post(self, kernel_id: str) -> None:
325+
body = self.get_json_body()
326+
327+
try:
328+
km = self.kernel_manager.get_kernel(kernel_id)
329+
except KeyError as e:
330+
msg = f"Unknown kernel with id: {kernel_id}"
331+
get_logger().error(msg, exc_info=e)
332+
raise tornado.web.HTTPError(status_code=HTTPStatus.NOT_FOUND, reason=msg) from e
333+
334+
client = km.client()
335+
336+
try:
337+
# only send stdin reply if there *was not* another request
338+
# or execution finished while we were reading.
339+
if not (await client.stdin_channel.msg_ready() or await client.shell_channel.msg_ready()):
340+
client.input(body["input"])
341+
finally:
342+
del client
343+
344+
298345
class RequestHandler(ExtensionHandlerMixin, APIHandler):
299346
"""Handler for /api/kernels/<kernel_id>/requests/<request_id>"""
300347

@@ -305,14 +352,15 @@ def initialize(
305352
self._stack = execution_stack
306353

307354
@tornado.web.authenticated
308-
def get(self, kernel_id: str, uid: str) -> None:
355+
def get(self, kernel_id: str, request_id: str) -> None:
309356
"""`GET /api/kernels/<kernel_id>/requests/<id>` Returns the request ``uid`` status.
310357
311358
Status are:
312359
313-
* 200: Task result is returned
314-
* 202: Task is pending
315-
* 500: Task ends with errors
360+
* 200: Request result is returned
361+
* 202: Request is pending
362+
* 300: Request has a pending input
363+
* 500: Request ends with errors
316364
317365
Args:
318366
index: Request identifier
@@ -321,7 +369,7 @@ def get(self, kernel_id: str, uid: str) -> None:
321369
404 if request ``uid`` does not exist
322370
"""
323371
try:
324-
r = self._stack.get(uid)
372+
r = self._stack.get(kernel_id, request_id)
325373
except ValueError as err:
326374
raise tornado.web.HTTPError(404, reason=str(err)) from err
327375
else:
@@ -332,12 +380,15 @@ def get(self, kernel_id: str, uid: str) -> None:
332380
if "error" in r:
333381
self.set_status(500)
334382
self.log.debug(f"{r}")
383+
elif "input_request" in r:
384+
self.set_status(300)
385+
self.set_header("Location", f"/api/kernels/{kernel_id}/input")
335386
else:
336387
self.set_status(200)
337388
self.finish(json.dumps(r))
338389

339390
@tornado.web.authenticated
340-
def delete(self, kernel_id: str, uid: str) -> None:
391+
def delete(self, kernel_id: str, request_id: str) -> None:
341392
"""`DELETE /api/kernels/<kernel_id>/requests/<id>` cancels the request ``uid``.
342393
343394
Status are:
@@ -350,7 +401,7 @@ def delete(self, kernel_id: str, uid: str) -> None:
350401
404 if request ``uid`` does not exist
351402
"""
352403
try:
353-
self._stack.cancel(int(uid))
404+
self._stack.cancel(request_id)
354405
except ValueError as err:
355406
raise tornado.web.HTTPError(404, reason=str(err)) from err
356407
else:

0 commit comments

Comments
 (0)