Skip to content

Commit 35ffe5d

Browse files
ojarjur and blink1073 authored
Merge the gateway handlers into the standard handlers. (#1261)
Co-authored-by: Steven Silvester <[email protected]>
1 parent 54d7292 commit 35ffe5d

File tree

7 files changed

+293
-14
lines changed

7 files changed

+293
-14
lines changed

docs/source/api/jupyter_server.gateway.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ Submodules
55
----------
66

77

8+
.. automodule:: jupyter_server.gateway.connections
9+
:members:
10+
:undoc-members:
11+
:show-inheritance:
12+
13+
814
.. automodule:: jupyter_server.gateway.gateway_client
915
:members:
1016
:undoc-members:

jupyter_server/gateway/connections.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
"""Gateway connection classes."""
2+
# Copyright (c) Jupyter Development Team.
3+
# Distributed under the terms of the Modified BSD License.
4+
5+
import asyncio
6+
import logging
7+
import random
8+
from typing import Any, cast
9+
10+
import tornado.websocket as tornado_websocket
11+
from tornado.concurrent import Future
12+
from tornado.escape import json_decode, url_escape, utf8
13+
from tornado.httpclient import HTTPRequest
14+
from tornado.ioloop import IOLoop
15+
from traitlets import Bool, Instance, Int
16+
17+
from ..services.kernels.connection.base import BaseKernelWebsocketConnection
18+
from ..utils import url_path_join
19+
from .managers import GatewayClient
20+
21+
22+
class GatewayWebSocketConnection(BaseKernelWebsocketConnection):
    """Web socket connection that proxies to a kernel/enterprise gateway.

    Messages read from the gateway websocket are forwarded to the notebook
    client through ``self.websocket_handler``; messages coming from the
    notebook client are written to the gateway websocket.
    """

    # The established websocket to the gateway. ``None`` until a connection
    # attempt has completed successfully (and reset to ``None`` by ``connect``).
    ws = Instance(klass=tornado_websocket.WebSocketClientConnection, allow_none=True)

    # Future of the most recent connection attempt. It is deliberately left
    # unset until ``connect`` runs: an eagerly constructed ``Future()`` default
    # would be created once at class-definition time and shared by every
    # instance.
    ws_future = Instance(klass=Future, allow_none=True)

    # Set by ``disconnect``; tornado's ``cancel()`` cannot be relied upon, so
    # cancellation of a pending connection is tracked locally with this flag.
    disconnected = Bool(False)

    # Number of reconnection attempts made since the last successful connect.
    retry = Int(0)

    async def connect(self):
        """Connect to the socket.

        Starts a websocket connection to the gateway's ``channels`` endpoint
        for ``self.kernel_id`` and schedules ``_read_messages`` to run once the
        connection attempt has finished.
        """
        # websocket is initialized before connection
        self.ws = None
        gateway_client = GatewayClient.instance()
        ws_url = url_path_join(
            gateway_client.ws_url,
            gateway_client.kernels_endpoint,
            url_escape(self.kernel_id),
            "channels",
        )
        self.log.info(f"Connecting to {ws_url}")
        kwargs: dict = {}
        kwargs = gateway_client.load_connection_args(**kwargs)

        request = HTTPRequest(ws_url, **kwargs)
        self.ws_future = cast(Future, tornado_websocket.websocket_connect(request))
        self.ws_future.add_done_callback(self._connection_done)

        loop = IOLoop.current()
        loop.add_future(self.ws_future, lambda future: self._read_messages())

    def _connection_done(self, fut):
        """Handle a finished connection.

        On success (and if the client has not disconnected in the meantime)
        the websocket is stored in ``self.ws`` and the retry counter is reset.
        Otherwise a warning is logged and ``self.ws`` is left unset.
        """
        # Check ``cancelled()`` first: calling ``exception()`` on a cancelled
        # future raises CancelledError instead of returning it.
        failed = fut.cancelled() or fut.exception() is not None
        if not self.disconnected and not failed:
            self.ws = fut.result()
            self.retry = 0
            self.log.debug(f"Connection is ready: ws: {self.ws}")
            return

        if not failed:
            # The client disconnected while the connection attempt was still
            # in flight. Nothing else will ever close this late connection,
            # so close it here rather than leaking it.
            fut.result().close()

        self.log.warning(
            "Websocket connection has been closed via client disconnect or due to error. "
            "Kernel with ID '{}' may not be terminated on GatewayClient: {}".format(
                self.kernel_id, GatewayClient.instance().url
            )
        )

    def disconnect(self):
        """Handle a disconnect.

        Closes the gateway websocket if one is established; otherwise cancels
        a still-pending connection attempt, if any.
        """
        self.disconnected = True
        if self.ws is not None:
            # Close connection
            self.ws.close()
        elif self.ws_future is not None and not self.ws_future.done():
            # Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally
            self.ws_future.cancel()
            self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}")

    async def _read_messages(self):
        """Read messages from gateway server.

        Forwards every message to the notebook client until the connection is
        lost or the client disconnects, then attempts to reconnect with an
        exponential back-off (plus jitter) while retries remain.
        """
        while self.ws is not None and not self.disconnected:
            message = None
            try:
                message = await self.ws.read_message()
            except Exception as e:
                self.log.error(
                    f"Exception reading message from websocket: {e}"
                )  # , exc_info=True)
            if message is None:
                # ``None`` means the read failed or the connection was closed.
                if not self.disconnected:
                    self.log.warning(f"Lost connection to Gateway: {self.kernel_id}")
                break
            self.handle_outgoing_message(
                message
            )  # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)

        # NOTE(esevan): if websocket is not disconnected by client, try to reconnect.
        gateway_client = GatewayClient.instance()
        if not self.disconnected and self.retry < gateway_client.gateway_retry_max:
            jitter = random.randint(10, 100) * 0.01  # noqa
            retry_interval = (
                min(
                    gateway_client.gateway_retry_interval * (2**self.retry),
                    gateway_client.gateway_retry_interval_max,
                )
                + jitter
            )
            self.retry += 1
            self.log.info(
                "Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s",
                retry_interval,
                self.retry,
                gateway_client.gateway_retry_max,
                self.kernel_id,
            )
            await asyncio.sleep(retry_interval)
            if self.disconnected:
                # The client went away during the back-off; opening a new
                # gateway connection now would serve nobody.
                return
            loop = IOLoop.current()
            loop.spawn_callback(self.connect)

    def handle_outgoing_message(self, incoming_msg: str, *args: Any) -> None:
        """Send message to the notebook client.

        If the notebook client has already closed its websocket the message is
        dropped; a summary of it is logged at debug level.
        """
        try:
            self.websocket_handler.write_message(incoming_msg)
        except tornado_websocket.WebSocketClosedError:
            # Only decode the message when the summary will actually be logged.
            if self.log.isEnabledFor(logging.DEBUG):
                msg_summary = GatewayWebSocketConnection._get_message_summary(
                    json_decode(utf8(incoming_msg))
                )
                self.log.debug(
                    "Notebook client closed websocket connection - message dropped: {}".format(
                        msg_summary
                    )
                )

    def handle_incoming_message(self, message: str) -> None:
        """Send message to gateway server.

        While a connection attempt is outstanding the message is re-submitted
        once ``ws_future`` completes; otherwise it is written immediately.
        """
        waiting_for_connection = (
            self.ws is None and self.ws_future is not None and not self.disconnected
        )
        if waiting_for_connection:
            loop = IOLoop.current()
            loop.add_future(self.ws_future, lambda future: self.handle_incoming_message(message))
        else:
            # Once disconnected, ``self.ws`` will never be set again, so
            # re-queuing would reschedule this callback forever;
            # ``_write_message`` simply drops the message in that state.
            self._write_message(message)

    def _write_message(self, message):
        """Send message to gateway server.

        Does nothing when disconnected or when no websocket is established.
        Errors raised by the write are logged rather than propagated.
        """
        try:
            if not self.disconnected and self.ws is not None:
                self.ws.write_message(message)
        except Exception as e:
            self.log.error(f"Exception writing message to websocket: {e}")  # , exc_info=True)

    @staticmethod
    def _get_message_summary(message):
        """Get a summary of a message.

        ``message`` is a decoded message dict with at least a ``msg_type``
        key. Only ``status`` and ``error`` messages have content details
        included in the returned string.
        """
        summary = []
        message_type = message["msg_type"]
        summary.append(f"type: {message_type}")

        if message_type == "status":
            summary.append(", state: {}".format(message["content"]["execution_state"]))
        elif message_type == "error":
            summary.append(
                ", {}:{}:{}".format(
                    message["content"]["ename"],
                    message["content"]["evalue"],
                    message["content"]["traceback"],
                )
            )
        else:
            summary.append(", ...")  # don't display potentially sensitive data

        return "".join(summary)

jupyter_server/gateway/handlers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import mimetypes
77
import os
88
import random
9+
import warnings
910
from typing import Optional, cast
1011

1112
from jupyter_client.session import Session
@@ -21,6 +22,13 @@
2122
from ..utils import url_path_join
2223
from .managers import GatewayClient
2324

25+
# Emitted at import time so that anything still importing this module is
# told about the deprecation.
warnings.warn(
    "The jupyter_server.gateway.handlers module is deprecated and will not be "
    "supported in Jupyter Server 3.0",
    DeprecationWarning,
    stacklevel=2,
)
30+
31+
2432
# Keepalive ping interval (default: 30 seconds)
2533
GATEWAY_WS_PING_INTERVAL_SECS = int(os.getenv("GATEWAY_WS_PING_INTERVAL_SECS", "30"))
2634

jupyter_server/kernelspecs/handlers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Kernelspecs API Handlers."""
2+
import mimetypes
3+
24
from jupyter_core.utils import ensure_async
35
from tornado import web
46

@@ -27,6 +29,26 @@ async def get(self, kernel_name, path, include_body=True):
2729
ksm = self.kernel_spec_manager
2830
if path.lower().endswith(".png"):
2931
self.set_header("Cache-Control", f"max-age={60*60*24*30}")
32+
ksm = self.kernel_spec_manager
33+
if hasattr(ksm, "get_kernel_spec_resource"):
34+
# If the kernel spec manager defines a method to get kernelspec resources,
35+
# then use that instead of trying to read from disk.
36+
kernel_spec_res = await ksm.get_kernel_spec_resource(kernel_name, path)
37+
if kernel_spec_res is not None:
38+
# We have to explicitly specify the `absolute_path` attribute so that
39+
# the underlying StaticFileHandler methods can calculate an etag.
40+
self.absolute_path = path
41+
mimetype: str = mimetypes.guess_type(path)[0] or "text/plain"
42+
self.set_header("Content-Type", mimetype)
43+
self.finish(kernel_spec_res)
44+
return
45+
else:
46+
self.log.warning(
47+
"Kernelspec resource '{}' for '{}' not found. Kernel spec manager may"
48+
" not support resource serving. Falling back to reading from disk".format(
49+
path, kernel_name
50+
)
51+
)
3052
try:
3153
kspec = await ensure_async(ksm.get_kernel_spec(kernel_name))
3254
self.root = kspec.resource_dir

jupyter_server/serverapp.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
from jupyter_server.extension.config import ExtensionConfigManager
8888
from jupyter_server.extension.manager import ExtensionManager
8989
from jupyter_server.extension.serverextension import ServerExtensionApp
90+
from jupyter_server.gateway.connections import GatewayWebSocketConnection
9091
from jupyter_server.gateway.managers import (
9192
GatewayClient,
9293
GatewayKernelSpecManager,
@@ -433,17 +434,6 @@ def init_handlers(self, default_services, settings):
433434
# And from identity provider
434435
handlers.extend(settings["identity_provider"].get_handlers())
435436

436-
# If gateway mode is enabled, replace appropriate handlers to perform redirection
437-
if GatewayClient.instance().gateway_enabled:
438-
# for each handler required for gateway, locate its pattern
439-
# in the current list and replace that entry...
440-
gateway_handlers = load_handlers("jupyter_server.gateway.handlers")
441-
for _, gwh in enumerate(gateway_handlers):
442-
for j, h in enumerate(handlers):
443-
if gwh[0] == h[0]:
444-
handlers[j] = (gwh[0], gwh[1])
445-
break
446-
447437
# register base handlers last
448438
handlers.extend(load_handlers("jupyter_server.base.handlers"))
449439

@@ -796,6 +786,7 @@ class ServerApp(JupyterApp):
796786
GatewayMappingKernelManager,
797787
GatewayKernelSpecManager,
798788
GatewaySessionManager,
789+
GatewayWebSocketConnection,
799790
GatewayClient,
800791
Authorizer,
801792
EventLogger,
@@ -1505,12 +1496,17 @@ def _default_session_manager_class(self):
15051496
return SessionManager
15061497

15071498
kernel_websocket_connection_class = Type(
1508-
default_value=ZMQChannelsWebsocketConnection,
15091499
klass=BaseKernelWebsocketConnection,
15101500
config=True,
15111501
help=_i18n("The kernel websocket connection class to use."),
15121502
)
15131503

1504+
@default("kernel_websocket_connection_class")
def _default_kernel_websocket_connection_class(self):
    """Pick the default kernel websocket connection class.

    Returns the import path of the gateway connection class when the
    gateway is enabled, and the ZMQ channels connection class otherwise.
    """
    if not self.gateway_config.gateway_enabled:
        return ZMQChannelsWebsocketConnection
    return "jupyter_server.gateway.connections.GatewayWebSocketConnection"
1509+
15141510
config_manager_class = Type(
15151511
default_value=ConfigManager,
15161512
config=True,
@@ -2876,7 +2872,19 @@ async def _cleanup(self):
28762872
self.remove_browser_open_files()
28772873
await self.cleanup_extensions()
28782874
await self.cleanup_kernels()
2879-
await self.kernel_websocket_connection_class.close_all()
2875+
try:
2876+
await self.kernel_websocket_connection_class.close_all()
2877+
except AttributeError:
2878+
# This can happen in two different scenarios:
2879+
#
2880+
# 1. During tests, where the _cleanup method is invoked without
2881+
# the corresponding initialize method having been invoked.
2882+
# 2. If the provided `kernel_websocket_connection_class` does not
2883+
# implement the `close_all` class method.
2884+
#
2885+
# In either case, we don't need to do anything and just want to treat
2886+
# the raised error as a no-op.
2887+
pass
28802888
if getattr(self, "kernel_manager", None):
28812889
self.kernel_manager.__del__()
28822890
if getattr(self, "session_manager", None):

tests/test_gateway.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,18 @@
1515
import pytest
1616
import tornado
1717
from jupyter_core.utils import ensure_async
18+
from tornado.concurrent import Future
1819
from tornado.httpclient import HTTPRequest, HTTPResponse
20+
from tornado.httputil import HTTPServerRequest
21+
from tornado.queues import Queue
1922
from tornado.web import HTTPError
2023
from traitlets import Int, Unicode
2124
from traitlets.config import Config
2225

26+
from jupyter_server.gateway.connections import GatewayWebSocketConnection
2327
from jupyter_server.gateway.gateway_client import GatewayTokenRenewerBase, NoOpTokenRenewer
2428
from jupyter_server.gateway.managers import ChannelQueue, GatewayClient, GatewayKernelManager
29+
from jupyter_server.services.kernels.websocket import KernelWebsocketHandler
2530

2631
from .utils import expected_http_error
2732

@@ -659,6 +664,61 @@ async def test_channel_queue_get_msg_when_response_router_had_finished():
659664
await queue.get_msg()
660665

661666

667+
class MockWebSocketClientConnection(tornado.websocket.WebSocketClientConnection):
    """Websocket client stand-in backed by a bounded in-memory queue."""

    def __init__(self, *args, **kwargs):
        """Seed the queue with a single kernel ``status`` message.

        The parent initializer is not invoked and the arguments are ignored.
        """
        starting_status = '{"msg_type": "status", "content": {"execution_state": "starting"}}'
        self._msgs: Queue = Queue(2)
        self._msgs.put_nowait(starting_status)

    def write_message(self, message, *args, **kwargs):
        """Enqueue ``message`` and return whatever the queue's ``put`` returns."""
        return self._msgs.put(message)

    def read_message(self, *args, **kwargs):
        """Return whatever the queue's ``get`` returns for the next message."""
        return self._msgs.get()
677+
678+
679+
def mock_websocket_connect():
    """Build a replacement for ``websocket_connect`` that connects instantly.

    The returned callable ignores its ``request`` argument and hands back a
    future already resolved with a fresh ``MockWebSocketClientConnection``.
    """

    def helper(request):
        connected: Future = Future()
        connected.set_result(MockWebSocketClientConnection())
        return connected

    return helper
687+
688+
689+
@patch("tornado.websocket.websocket_connect", mock_websocket_connect())
async def test_websocket_connection_closed(init_gateway, jp_serverapp, jp_fetch, caplog):
    """Check that no error is logged when the client websocket is already closing."""
    # Start a kernel and look up its manager.
    kernel_id = await create_kernel(jp_fetch, "kspec_foo")
    km: GatewayKernelManager = jp_serverapp.kernel_manager.get_kernel(kernel_id)

    # Build a KernelWebsocketHandler around a minimal fake request.
    request = HTTPServerRequest("foo", "GET")
    request.connection = MagicMock()
    handler = KernelWebsocketHandler(jp_serverapp.web_app, request)

    # Make the handler's client-side websocket report that it is closing, so
    # that any attempt to write a message to the client raises a closed error.
    closing_ws = MagicMock()
    closing_ws.is_closing = lambda: True
    handler.ws_connection = closing_ws

    # Wire a GatewayWebSocketConnection to the handler and connect it.
    conn = GatewayWebSocketConnection(parent=km, websocket_handler=handler)
    handler.connection = conn
    await conn.connect()

    # Websocket messages are processed in separate coroutines, so failures
    # there only surface in the logs and never propagate to this caller.
    # Wait for the server to stop, then fail on the first error-level record.
    await jp_serverapp._cleanup()
    error_messages = [
        message for _, level, message in caplog.record_tuples if level >= logging.ERROR
    ]
    if error_messages:
        pytest.fail(f"Logs contain an error: {error_messages[0]}")
720+
721+
662722
#
663723
# Test methods below...
664724
#

0 commit comments

Comments
 (0)