|
| 1 | +import re |
| 2 | +from typing import Optional, no_type_check |
| 3 | +from urllib.parse import urlparse |
| 4 | + |
| 5 | +from tornado import ioloop |
| 6 | +from tornado.iostream import IOStream |
| 7 | + |
| 8 | +# ping interval for keeping websockets alive (30 seconds) |
| 9 | +WS_PING_INTERVAL = 30000 |
| 10 | + |
| 11 | + |
| 12 | +class WebSocketMixin: |
| 13 | + """Mixin for common websocket options""" |
| 14 | + |
| 15 | + ping_callback = None |
| 16 | + last_ping = 0.0 |
| 17 | + last_pong = 0.0 |
| 18 | + stream = None # type: Optional[IOStream] |
| 19 | + |
| 20 | + @property |
| 21 | + def ping_interval(self): |
| 22 | + """The interval for websocket keep-alive pings. |
| 23 | +
|
| 24 | + Set ws_ping_interval = 0 to disable pings. |
| 25 | + """ |
| 26 | + return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) # type:ignore[attr-defined] |
| 27 | + |
| 28 | + @property |
| 29 | + def ping_timeout(self): |
| 30 | + """If no ping is received in this many milliseconds, |
| 31 | + close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). |
| 32 | + Default is max of 3 pings or 30 seconds. |
| 33 | + """ |
| 34 | + return self.settings.get( # type:ignore[attr-defined] |
| 35 | + "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL) |
| 36 | + ) |
| 37 | + |
| 38 | + @no_type_check |
| 39 | + def check_origin(self, origin: Optional[str] = None) -> bool: |
| 40 | + """Check Origin == Host or Access-Control-Allow-Origin. |
| 41 | +
|
| 42 | + Tornado >= 4 calls this method automatically, raising 403 if it returns False. |
| 43 | + """ |
| 44 | + |
| 45 | + if self.allow_origin == "*" or ( |
| 46 | + hasattr(self, "skip_check_origin") and self.skip_check_origin() |
| 47 | + ): |
| 48 | + return True |
| 49 | + |
| 50 | + host = self.request.headers.get("Host") |
| 51 | + if origin is None: |
| 52 | + origin = self.get_origin() |
| 53 | + |
| 54 | + # If no origin or host header is provided, assume from script |
| 55 | + if origin is None or host is None: |
| 56 | + return True |
| 57 | + |
| 58 | + origin = origin.lower() |
| 59 | + origin_host = urlparse(origin).netloc |
| 60 | + |
| 61 | + # OK if origin matches host |
| 62 | + if origin_host == host: |
| 63 | + return True |
| 64 | + |
| 65 | + # Check CORS headers |
| 66 | + if self.allow_origin: |
| 67 | + allow = self.allow_origin == origin |
| 68 | + elif self.allow_origin_pat: |
| 69 | + allow = bool(re.match(self.allow_origin_pat, origin)) |
| 70 | + else: |
| 71 | + # No CORS headers deny the request |
| 72 | + allow = False |
| 73 | + if not allow: |
| 74 | + self.log.warning( |
| 75 | + "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s", |
| 76 | + origin, |
| 77 | + host, |
| 78 | + ) |
| 79 | + return allow |
| 80 | + |
| 81 | + def clear_cookie(self, *args, **kwargs): |
| 82 | + """meaningless for websockets""" |
| 83 | + pass |
| 84 | + |
| 85 | + @no_type_check |
| 86 | + def open(self, *args, **kwargs): |
| 87 | + self.log.debug("Opening websocket %s", self.request.path) |
| 88 | + |
| 89 | + # start the pinging |
| 90 | + if self.ping_interval > 0: |
| 91 | + loop = ioloop.IOLoop.current() |
| 92 | + self.last_ping = loop.time() # Remember time of last ping |
| 93 | + self.last_pong = self.last_ping |
| 94 | + self.ping_callback = ioloop.PeriodicCallback( |
| 95 | + self.send_ping, |
| 96 | + self.ping_interval, |
| 97 | + ) |
| 98 | + self.ping_callback.start() |
| 99 | + return super().open(*args, **kwargs) |
| 100 | + |
| 101 | + @no_type_check |
| 102 | + def send_ping(self): |
| 103 | + """send a ping to keep the websocket alive""" |
| 104 | + if self.ws_connection is None and self.ping_callback is not None: |
| 105 | + self.ping_callback.stop() |
| 106 | + return |
| 107 | + |
| 108 | + if self.ws_connection.client_terminated: |
| 109 | + self.close() |
| 110 | + return |
| 111 | + |
| 112 | + # check for timeout on pong. Make sure that we really have sent a recent ping in |
| 113 | + # case the machine with both server and client has been suspended since the last ping. |
| 114 | + now = ioloop.IOLoop.current().time() |
| 115 | + since_last_pong = 1e3 * (now - self.last_pong) |
| 116 | + since_last_ping = 1e3 * (now - self.last_ping) |
| 117 | + if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout: |
| 118 | + self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong) |
| 119 | + self.close() |
| 120 | + return |
| 121 | + |
| 122 | + self.ping(b"") |
| 123 | + self.last_ping = now |
| 124 | + |
| 125 | + def on_pong(self, data): |
| 126 | + self.last_pong = ioloop.IOLoop.current().time() |
0 commit comments