Skip to content

Commit 029be08

Browse files
authored
Implement websocket keepalive pings for websockets-sansio (#2888)
1 parent 8d397c7 commit 029be08

File tree

3 files changed

+167
-3
lines changed

3 files changed

+167
-3
lines changed

docs/settings.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ Using Uvicorn with watchfiles will enable the following options (which are other
9595
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. There are two versions of `websockets` supported: `websockets` and `websockets-sansio`. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*.
9696
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. Only available with the `websockets` protocol. **Default:** *16777216* (16 MB).
9797
* `--ws-max-queue <int>` - Set the maximum length of the WebSocket incoming message queue. Only available with the `websockets` protocol. **Default:** *32*.
98-
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Only available with the `websockets` protocol. **Default:** *20.0*.
99-
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Only available with the `websockets` protocol. **Default:** *20.0*.
98+
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
99+
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
100100
* `--ws-per-message-deflate <bool>` - Enable/disable WebSocket per-message-deflate compression. Only available with the `websockets` protocol. **Default:** *True*.
101101
* `--lifespan <str>` - Set the Lifespan protocol implementation. **Options:** *'auto', 'on', 'off'.* **Default:** *'auto'*.
102102
* `--h11-max-incomplete-event-size <int>` - Set the maximum number of bytes to buffer of an incomplete event. Only available for `h11` HTTP protocol implementation. **Default:** *16384* (16 KB).

tests/protocols/test_websocket.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from uvicorn.config import Config
2929
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
30+
from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol
3031

3132
try:
3233
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol as _WSProtocol
@@ -1202,3 +1203,90 @@ async def app_wrapper(scope: Scope, receive: ASGIReceiveCallable, send: ASGISend
12021203
assert is_open
12031204

12041205
assert expected_states == actual_states
1206+
1207+
1208+
async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
1209+
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
1210+
while True:
1211+
message = await receive()
1212+
if message["type"] == "websocket.connect":
1213+
await send({"type": "websocket.accept"})
1214+
elif message["type"] == "websocket.disconnect":
1215+
break
1216+
1217+
config = Config(
1218+
app=app,
1219+
ws=WebSocketsSansIOProtocol,
1220+
http=http_protocol_cls,
1221+
lifespan="off",
1222+
ws_ping_interval=0.1,
1223+
ws_ping_timeout=5.0,
1224+
port=unused_tcp_port,
1225+
)
1226+
async with run_server(config) as server:
1227+
# The websockets client auto-responds to ping frames, keeping the connection alive.
1228+
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
1229+
protocol = list(server.server_state.connections)[0]
1230+
assert isinstance(protocol, WebSocketsSansIOProtocol)
1231+
1232+
# Wait until at least one ping/pong roundtrip completes.
1233+
async def ping_roundtrip() -> None:
1234+
while protocol.last_ping_rtt == 0.0:
1235+
await asyncio.sleep(0.1)
1236+
1237+
await asyncio.wait_for(ping_roundtrip(), timeout=5.0)
1238+
assert protocol.last_ping_rtt > 0
1239+
1240+
1241+
async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
1242+
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
1243+
while True:
1244+
message = await receive()
1245+
if message["type"] == "websocket.connect":
1246+
await send({"type": "websocket.accept"})
1247+
elif message["type"] == "websocket.disconnect":
1248+
break
1249+
1250+
config = Config(
1251+
app=app,
1252+
ws=WebSocketsSansIOProtocol,
1253+
http=http_protocol_cls,
1254+
lifespan="off",
1255+
ws_ping_interval=0.1,
1256+
ws_ping_timeout=0.1,
1257+
log_level="trace",
1258+
port=unused_tcp_port,
1259+
)
1260+
async with run_server(config):
1261+
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None) as websocket:
1262+
# Swallow outgoing pong frames so the server's ping never gets ack'd.
1263+
websocket.transport.write = lambda data: None # type: ignore[method-assign]
1264+
with pytest.raises(websockets.exceptions.ConnectionClosedError) as exc_info:
1265+
await asyncio.wait_for(websocket.recv(), timeout=1)
1266+
assert exc_info.value.rcvd is not None
1267+
assert exc_info.value.rcvd.code == 1011
1268+
assert exc_info.value.rcvd.reason == "keepalive ping timeout"
1269+
1270+
1271+
async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
1272+
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
1273+
while True:
1274+
message = await receive()
1275+
if message["type"] == "websocket.connect":
1276+
await send({"type": "websocket.accept"})
1277+
elif message["type"] == "websocket.disconnect":
1278+
break
1279+
1280+
config = Config(
1281+
app=app,
1282+
ws=WebSocketsSansIOProtocol,
1283+
http=http_protocol_cls,
1284+
lifespan="off",
1285+
ws_ping_interval=None,
1286+
port=unused_tcp_port,
1287+
)
1288+
async with run_server(config) as server:
1289+
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
1290+
protocol = list(server.server_state.connections)[0]
1291+
assert isinstance(protocol, WebSocketsSansIOProtocol)
1292+
assert protocol.ping_timer is None

uvicorn/protocols/websockets/websockets_sansio_impl.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
import asyncio
44
import logging
5+
import random
6+
import struct
57
import sys
8+
from asyncio import TimerHandle
69
from asyncio.transports import BaseTransport, Transport
710
from http import HTTPStatus
811
from typing import Any, Literal, cast
@@ -92,6 +95,15 @@ def __init__(
9295
self.writable = asyncio.Event()
9396
self.writable.set()
9497

98+
# Keepalive state
99+
self.ping_interval = config.ws_ping_interval
100+
self.ping_timeout = config.ws_ping_timeout
101+
self.ping_timer: TimerHandle | None = None
102+
self.pong_timer: TimerHandle | None = None
103+
self.pending_ping_payload: bytes | None = None
104+
self.ping_sent_at: float = 0.0
105+
self.last_ping_rtt: float = 0.0
106+
95107
# Buffers
96108
self.bytes = b""
97109

@@ -109,6 +121,7 @@ def connection_made(self, transport: BaseTransport) -> None:
109121
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
110122

111123
def connection_lost(self, exc: Exception | None) -> None:
124+
self.stop_keepalive()
112125
code = 1005 if self.handshake_complete else 1006
113126
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
114127
self.connections.remove(self)
@@ -125,6 +138,7 @@ def eof_received(self) -> None:
125138
pass
126139

127140
def shutdown(self) -> None:
141+
self.stop_keepalive()
128142
if self.handshake_complete:
129143
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
130144
self.conn.send_close(1012)
@@ -155,7 +169,7 @@ def handle_events(self) -> None:
155169
elif event.opcode == Opcode.PING:
156170
self.handle_ping()
157171
elif event.opcode == Opcode.PONG:
158-
pass # pragma: no cover
172+
self.handle_pong(event)
159173
elif event.opcode == Opcode.CLOSE:
160174
self.handle_close(event)
161175
else:
@@ -238,6 +252,67 @@ def handle_ping(self) -> None:
238252
output = self.conn.data_to_send()
239253
self.transport.write(b"".join(output))
240254

255+
def handle_pong(self, event: Frame) -> None:
256+
# Ignore unsolicited pongs and stale pongs whose payload doesn't match the ping currently in flight
257+
if self.pending_ping_payload is None or bytes(event.data) != self.pending_ping_payload:
258+
return # pragma: no cover
259+
260+
self.last_ping_rtt = self.loop.time() - self.ping_sent_at
261+
self.pending_ping_payload = None
262+
# The peer answered in time; cancel the pong deadline and chain the next ping. This `schedule_ping()` call is
263+
# what keeps the keepalive loop running when ping_timeout is set. When ping_timeout is None the next ping is
264+
# already scheduled by `send_keepalive_ping`, so we must not schedule a duplicate here.
265+
if self.pong_timer is not None:
266+
self.pong_timer.cancel()
267+
self.pong_timer = None
268+
self.schedule_ping()
269+
270+
def start_keepalive(self) -> None:
271+
if self.ping_interval is not None and self.ping_interval > 0:
272+
self.schedule_ping()
273+
274+
def stop_keepalive(self) -> None:
275+
if self.ping_timer is not None:
276+
self.ping_timer.cancel()
277+
self.ping_timer = None
278+
if self.pong_timer is not None: # pragma: no cover
279+
self.pong_timer.cancel()
280+
self.pong_timer = None
281+
self.pending_ping_payload = None
282+
283+
def schedule_ping(self) -> None:
284+
assert self.ping_interval is not None
285+
delay = max(0.0, self.ping_interval - self.last_ping_rtt)
286+
self.ping_timer = self.loop.call_later(delay, self.send_keepalive_ping)
287+
288+
def send_keepalive_ping(self) -> None:
289+
self.ping_timer = None
290+
if self.close_sent or self.transport.is_closing(): # pragma: no cover
291+
return
292+
# Random 4-byte payload identifies this ping; `handle_pong` uses it to ignore stale or unsolicited pongs.
293+
# See https://github.com/python-websockets/websockets/blob/4d229bf9f583d593aa103287aee0a77c9fbc3a79/src/websockets/asyncio/connection.py#L624
294+
self.pending_ping_payload = struct.pack("!I", random.getrandbits(32))
295+
self.ping_sent_at = self.loop.time()
296+
self.conn.send_ping(self.pending_ping_payload)
297+
self.transport.write(b"".join(self.conn.data_to_send()))
298+
if self.ping_timeout is not None:
299+
self.pong_timer = self.loop.call_later(self.ping_timeout, self.keepalive_timeout)
300+
else: # pragma: no cover
301+
self.schedule_ping()
302+
303+
def keepalive_timeout(self) -> None:
304+
self.pong_timer = None
305+
self.pending_ping_payload = None
306+
if self.close_sent or self.transport.is_closing(): # pragma: no cover
307+
return
308+
if self.logger.level <= TRACE_LOG_LEVEL:
309+
prefix = "%s:%d - " % self.client if self.client else ""
310+
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket keepalive ping timeout", prefix)
311+
self.conn.fail(1011, "keepalive ping timeout")
312+
self.transport.write(b"".join(self.conn.data_to_send()))
313+
self.close_sent = True
314+
self.transport.close()
315+
241316
def handle_close(self, event: Frame) -> None:
242317
if not self.close_sent and not self.transport.is_closing():
243318
assert self.conn.close_rcvd is not None
@@ -311,6 +386,7 @@ async def send(self, message: ASGISendEvent) -> None:
311386
self.conn.send_response(self.response)
312387
output = self.conn.data_to_send()
313388
self.transport.write(b"".join(output))
389+
self.start_keepalive()
314390

315391
elif message["type"] == "websocket.close":
316392
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})

0 commit comments

Comments
 (0)