Skip to content

Commit 339038b

Browse files
authored
Add more websocket connection tests and fix bugs (#1085)
1 parent 890b882 commit 339038b

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

jupyter_server/services/kernels/connection/channels.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,9 @@ def handle_outgoing_message(self, stream: str, outgoing_msg: list) -> None:
485485
else:
486486
msg = self.session.deserialize(fed_msg_list)
487487

488+
if isinstance(stream, str):
489+
stream = self.channels[stream]
490+
488491
channel = getattr(stream, "channel", None)
489492
parts = fed_msg_list[1:]
490493

@@ -534,7 +537,7 @@ def _reserialize_reply(self, msg_or_list, channel=None):
534537
return json.dumps(msg, default=json_default)
535538

536539
def select_subprotocol(self, subprotocols):
537-
preferred_protocol = self.settings.get("kernel_ws_protocol")
540+
preferred_protocol = self.kernel_ws_protocol
538541
if preferred_protocol is None:
539542
preferred_protocol = "v1.kernel.websocket.jupyter.org"
540543
elif preferred_protocol == "":
@@ -792,7 +795,7 @@ def on_restart_failed(self):
792795
self._send_status_message("dead")
793796

794797
def _on_error(self, channel, msg, msg_list):
795-
if self.kernel_manager.allow_tracebacks:
798+
if self.multi_kernel_manager.allow_tracebacks:
796799
return
797800

798801
if channel == "iopub":
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import asyncio
2+
import json
3+
from unittest.mock import MagicMock
4+
5+
from jupyter_client.jsonutil import json_clean, json_default
6+
from jupyter_client.session import Session
7+
from tornado.httpserver import HTTPRequest
8+
from tornado.websocket import WebSocketHandler
9+
10+
from jupyter_server.serverapp import ServerApp
11+
from jupyter_server.services.kernels.connection.channels import (
12+
ZMQChannelsWebsocketConnection,
13+
)
14+
15+
16+
async def test_websocket_connection(jp_serverapp):
17+
app: ServerApp = jp_serverapp
18+
kernel_id = await app.kernel_manager.start_kernel()
19+
kernel = app.kernel_manager.get_kernel(kernel_id)
20+
request = HTTPRequest("foo", "GET")
21+
request.connection = MagicMock()
22+
handler = WebSocketHandler(app.web_app, request)
23+
handler.ws_connection = MagicMock()
24+
handler.ws_connection.is_closing = lambda: False
25+
conn = ZMQChannelsWebsocketConnection(parent=kernel, websocket_handler=handler)
26+
await conn.prepare()
27+
conn.connect()
28+
await asyncio.wrap_future(conn.nudge())
29+
session: Session = kernel.session
30+
msg = session.msg("data_pub", content={"a": "b"})
31+
data = json.dumps(
32+
json_clean(msg),
33+
default=json_default,
34+
ensure_ascii=False,
35+
allow_nan=False,
36+
)
37+
conn.handle_incoming_message(data)
38+
conn.handle_outgoing_message("iopub", session.serialize(msg))
39+
assert (
40+
conn.select_subprotocol(["v1.kernel.websocket.jupyter.org"])
41+
== "v1.kernel.websocket.jupyter.org"
42+
)
43+
conn.write_stderr("test", {})
44+
conn.on_kernel_restarted()
45+
conn.on_restart_failed()
46+
conn._on_error("shell", msg, session.serialize(msg))

0 commit comments

Comments
 (0)