Skip to content

Commit 884b72e

Browse files
authored
New configurable/overridable kernel ZMQ+Websocket connection API (#1047)
* add new configurable websocket api * cleaning up unit tests * more updates for unit tests * all loop to ensure kernel is alive before connecting working unit tests * fix pre-commit errors * handle trait deprecation * cleanup from code review * ignore deprecation warning from zmqhandlers module * move base websocket mixin into its own module
1 parent fe87f69 commit 884b72e

File tree

12 files changed

+1324
-1133
lines changed

12 files changed

+1324
-1133
lines changed

jupyter_server/base/websocket.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import re
2+
from typing import Optional, no_type_check
3+
from urllib.parse import urlparse
4+
5+
from tornado import ioloop
6+
from tornado.iostream import IOStream
7+
8+
# ping interval for keeping websockets alive (30 seconds)
9+
WS_PING_INTERVAL = 30000
10+
11+
12+
class WebSocketMixin:
13+
"""Mixin for common websocket options"""
14+
15+
ping_callback = None
16+
last_ping = 0.0
17+
last_pong = 0.0
18+
stream = None # type: Optional[IOStream]
19+
20+
@property
21+
def ping_interval(self):
22+
"""The interval for websocket keep-alive pings.
23+
24+
Set ws_ping_interval = 0 to disable pings.
25+
"""
26+
return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) # type:ignore[attr-defined]
27+
28+
@property
29+
def ping_timeout(self):
30+
"""If no ping is received in this many milliseconds,
31+
close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
32+
Default is max of 3 pings or 30 seconds.
33+
"""
34+
return self.settings.get( # type:ignore[attr-defined]
35+
"ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)
36+
)
37+
38+
@no_type_check
39+
def check_origin(self, origin: Optional[str] = None) -> bool:
40+
"""Check Origin == Host or Access-Control-Allow-Origin.
41+
42+
Tornado >= 4 calls this method automatically, raising 403 if it returns False.
43+
"""
44+
45+
if self.allow_origin == "*" or (
46+
hasattr(self, "skip_check_origin") and self.skip_check_origin()
47+
):
48+
return True
49+
50+
host = self.request.headers.get("Host")
51+
if origin is None:
52+
origin = self.get_origin()
53+
54+
# If no origin or host header is provided, assume from script
55+
if origin is None or host is None:
56+
return True
57+
58+
origin = origin.lower()
59+
origin_host = urlparse(origin).netloc
60+
61+
# OK if origin matches host
62+
if origin_host == host:
63+
return True
64+
65+
# Check CORS headers
66+
if self.allow_origin:
67+
allow = self.allow_origin == origin
68+
elif self.allow_origin_pat:
69+
allow = bool(re.match(self.allow_origin_pat, origin))
70+
else:
71+
# No CORS headers deny the request
72+
allow = False
73+
if not allow:
74+
self.log.warning(
75+
"Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
76+
origin,
77+
host,
78+
)
79+
return allow
80+
81+
def clear_cookie(self, *args, **kwargs):
82+
"""meaningless for websockets"""
83+
pass
84+
85+
@no_type_check
86+
def open(self, *args, **kwargs):
87+
self.log.debug("Opening websocket %s", self.request.path)
88+
89+
# start the pinging
90+
if self.ping_interval > 0:
91+
loop = ioloop.IOLoop.current()
92+
self.last_ping = loop.time() # Remember time of last ping
93+
self.last_pong = self.last_ping
94+
self.ping_callback = ioloop.PeriodicCallback(
95+
self.send_ping,
96+
self.ping_interval,
97+
)
98+
self.ping_callback.start()
99+
return super().open(*args, **kwargs)
100+
101+
@no_type_check
102+
def send_ping(self):
103+
"""send a ping to keep the websocket alive"""
104+
if self.ws_connection is None and self.ping_callback is not None:
105+
self.ping_callback.stop()
106+
return
107+
108+
if self.ws_connection.client_terminated:
109+
self.close()
110+
return
111+
112+
# check for timeout on pong. Make sure that we really have sent a recent ping in
113+
# case the machine with both server and client has been suspended since the last ping.
114+
now = ioloop.IOLoop.current().time()
115+
since_last_pong = 1e3 * (now - self.last_pong)
116+
since_last_ping = 1e3 * (now - self.last_ping)
117+
if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
118+
self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
119+
self.close()
120+
return
121+
122+
self.ping(b"")
123+
self.last_ping = now
124+
125+
def on_pong(self, data):
126+
self.last_pong = ioloop.IOLoop.current().time()

0 commit comments

Comments
 (0)