Skip to content

Commit 440c6b8

Browse files
committed
Fix unit tests
1 parent eab3c4e commit 440c6b8

File tree

2 files changed

+66
-19
lines changed

2 files changed

+66
-19
lines changed

jupyter_server_nbmodel/handlers.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ def get(self, kernel_id: str, uid: str) -> t.Any:
8282
else:
8383
return None
8484

85-
def put(self, km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map) -> str:
85+
def put(
86+
self, km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map | None
87+
) -> str:
8688
"""Add a asynchronous execution request.
8789
8890
Args:
@@ -99,18 +101,27 @@ def put(self, km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.M
99101
)
100102
return uid
101103

102-
def _stdin_hook(self, kernel_id, msg) -> None:
104+
def _stdin_hook(self, kernel_id: str, msg: dict) -> None:
103105
get_logger().info(f"Execution request {kernel_id} received a input request {msg!s}")
104106
if kernel_id in self.__pending_inputs:
105-
get_logger().error(f"Execution request {kernel_id} received a input request while waiting for an input.\n{msg}")
107+
get_logger().error(
108+
f"Execution request {kernel_id} received a input request while waiting for an input.\n{msg}" # noqa: E501
109+
)
106110

107111
header = msg["header"].copy()
108112
header["date"] = header["date"].isoformat()
109-
self.__pending_inputs[kernel_id] = {"parent_header": header, "input_request": msg["content"]}
113+
self.__pending_inputs[kernel_id] = {
114+
"parent_header": header,
115+
"input_request": msg["content"],
116+
}
110117

111118

112119
async def execute_task(
113-
uid, km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map, stdin_hook
120+
uid,
121+
km: jupyter_client.manager.KernelManager,
122+
snippet: str,
123+
ycell: y.Map | None,
124+
stdin_hook: t.Callable[[dict], None] | None,
114125
) -> t.Any:
115126
try:
116127
get_logger().debug(f"Will execute request {uid}.")
@@ -134,15 +145,11 @@ async def execute_task(
134145

135146
async def _execute_snippet(
136147
uid: str,
137-
km: jupyter_client.client.KernelClient,
148+
km: jupyter_client.manager.KernelManager,
138149
snippet: str,
139-
ycell: y.Map,
140-
stdin_hook,
150+
ycell: y.Map | None,
151+
stdin_hook: t.Callable[[dict], None] | None,
141152
) -> dict[str, t.Any]:
142-
client = km.client()
143-
client.session.session = uid
144-
# FIXME
145-
# client.session.username = username
146153

147154
if ycell is not None:
148155
# Reset cell
@@ -151,6 +158,10 @@ async def _execute_snippet(
151158
ycell["execution_count"] = None
152159

153160
outputs = []
161+
client = km.client()
162+
client.session.session = uid
163+
# FIXME
164+
# client.session.username = username
154165

155166
# FIXME we don't check if the session is consistent (aka the kernel is linked to the document)
156167
# - should we?
@@ -175,6 +186,7 @@ async def _execute_snippet(
175186
"outputs": json.dumps(outputs),
176187
}
177188
finally:
189+
client.stop_channels()
178190
del client
179191

180192

@@ -252,6 +264,7 @@ async def post(self, kernel_id: str) -> None:
252264
body = self.get_json_body()
253265

254266
snippet = body.get("code")
267+
ycell = None
255268
# From RTC model
256269
if snippet is None:
257270
document_id = body.get("document_id")
@@ -280,7 +293,6 @@ async def post(self, kernel_id: str) -> None:
280293
raise tornado.web.HTTPError(status_code=HTTPStatus.NOT_FOUND, reason=msg)
281294

282295
ycells = filter(lambda c: c["id"] == cell_id, notebook.ycells)
283-
ycell = None
284296
try:
285297
ycell = next(ycells)
286298
except StopIteration:
@@ -337,9 +349,12 @@ async def post(self, kernel_id: str) -> None:
337349
try:
338350
# only send stdin reply if there *was not* another request
339351
# or execution finished while we were reading.
340-
if not (await client.stdin_channel.msg_ready() or await client.shell_channel.msg_ready()):
352+
if not (
353+
await client.stdin_channel.msg_ready() or await client.shell_channel.msg_ready()
354+
):
341355
client.input(body["input"])
342356
finally:
357+
client.stop_channels()
343358
del client
344359

345360

jupyter_server_nbmodel/tests/test_handlers.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,41 @@
11
import asyncio
2+
import datetime
23
import json
3-
4-
from jupyter_client.kernelspec import NATIVE_KERNEL_NAME
4+
import re
55

66
import pytest
7+
from jupyter_client.kernelspec import NATIVE_KERNEL_NAME
78

89
TEST_TIMEOUT = 60
10+
SLEEP = 0.1
11+
12+
13+
REQUEST_REGEX = re.compile(r"^/api/kernels/\w+-\w+-\w+-\w+-\w+/requests/\w+-\w+-\w+-\w+-\w+$")
14+
15+
16+
async def _wait_request(fetch, endpoint: str):
17+
"""Poll periodically to fetch the execution request result."""
18+
start_time = datetime.datetime.now()
19+
20+
while (datetime.datetime.now() - start_time).total_seconds() < 0.9 * TEST_TIMEOUT:
21+
await asyncio.sleep(SLEEP)
22+
response = await fetch(endpoint)
23+
response.rethrow()
24+
if response.code != 202:
25+
return response
26+
27+
raise TimeoutError(f"Request {endpoint} timed out.")
28+
29+
30+
async def wait_for_request(fetch, *args, **kwargs):
31+
"""Wait for execution request."""
32+
r = await fetch(*args, **kwargs)
33+
assert r.code == 202
34+
location = r.headers["Location"]
35+
assert REQUEST_REGEX.match(location) is not None
36+
37+
ans = await _wait_request(fetch, location)
38+
return ans
939

1040

1141
@pytest.fixture()
@@ -39,7 +69,7 @@ async def _(kernel_id, ready=None):
3969
(
4070
"""from IPython.display import HTML
4171
HTML('<p><b>Jupyter</b> rocks.</p>')""",
42-
'{"output_type": "execute_result", "metadata": {}, "data": {"text/plain": "<IPython.core.display.HTML object>", "text/html": "<p><b>Jupyter</b> rocks.</p>"}, "execution_count": 1}',
72+
'{"output_type": "execute_result", "metadata": {}, "data": {"text/plain": "<IPython.core.display.HTML object>", "text/html": "<p><b>Jupyter</b> rocks.</p>"}, "execution_count": 1}', # noqa: E501
4373
),
4474
),
4575
)
@@ -50,7 +80,8 @@ async def test_post_execute(jp_fetch, pending_kernel_is_ready, snippet, output):
5080
kernel = json.loads(r.body.decode())
5181
await pending_kernel_is_ready(kernel["id"])
5282

53-
response = await jp_fetch(
83+
response = await wait_for_request(
84+
jp_fetch,
5485
"api",
5586
"kernels",
5687
kernel["id"],
@@ -89,7 +120,8 @@ async def test_post_erroneous_execute(jp_fetch, pending_kernel_is_ready, snippet
89120
kernel = json.loads(r.body.decode())
90121
await pending_kernel_is_ready(kernel["id"])
91122

92-
response = await jp_fetch(
123+
response = await wait_for_request(
124+
jp_fetch,
93125
"api",
94126
"kernels",
95127
kernel["id"],

0 commit comments

Comments
 (0)