Skip to content

Commit 89f35d4

Browse files
committed
Merge #5499 revised fix: split WS auth (listen vs rate-limited)
2 parents b2eaeae + 62bc0c6 commit 89f35d4

File tree

3 files changed

+100
-58
lines changed

3 files changed

+100
-58
lines changed

backend/routers/transcribe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2483,7 +2483,7 @@ async def _listen(
24832483
@router.websocket("/v4/listen")
24842484
async def listen_handler(
24852485
websocket: WebSocket,
2486-
uid: str = Depends(auth.get_current_user_uid),
2486+
uid: str = Depends(auth.get_current_user_uid_ws_listen),
24872487
language: str = 'en',
24882488
sample_rate: int = 8000,
24892489
codec: str = 'pcm8',

backend/tests/unit/test_ws_auth_handshake.py

Lines changed: 77 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Tests for WebSocket auth handshake fix (#5447).
22
33
Verifies that:
4-
1. WebSocket endpoints send proper close frames on auth failure (not HTTPException)
5-
2. Per-UID rate limiting blocks retry storms
4+
1. get_current_user_uid_ws_listen sends proper close frames on auth failure (no rate limiter)
5+
2. get_current_user_uid_ws adds per-UID rate limiting on top of auth
66
3. /v4/web/listen is NOT affected (uses accept-first pattern)
77
"""
88

@@ -15,114 +15,142 @@
1515
from firebase_admin.auth import InvalidIdTokenError
1616
from starlette.websockets import WebSocketDisconnect
1717

18-
from utils.other.endpoints import get_current_user_uid_ws, get_current_user_uid
18+
from utils.other.endpoints import get_current_user_uid_ws_listen, get_current_user_uid_ws, get_current_user_uid
1919

2020

21-
class TestWebSocketAuthDependency(unittest.TestCase):
22-
"""Test that get_current_user_uid_ws raises WebSocketException instead of HTTPException."""
21+
class TestWebSocketAuthListen(unittest.TestCase):
22+
"""Test get_current_user_uid_ws_listen — auth-only, no rate limiter (used by /v4/listen)."""
2323

2424
def setUp(self):
2525
self.app = FastAPI()
2626

27-
@self.app.websocket("/ws-new")
28-
async def ws_new(websocket: WebSocket, uid: str = Depends(get_current_user_uid_ws)):
29-
await websocket.accept()
30-
await websocket.send_json({"uid": uid})
31-
await websocket.close()
32-
33-
@self.app.websocket("/ws-old")
34-
async def ws_old(websocket: WebSocket, uid: str = Depends(get_current_user_uid)):
27+
@self.app.websocket("/ws-listen")
28+
async def ws_listen(websocket: WebSocket, uid: str = Depends(get_current_user_uid_ws_listen)):
3529
await websocket.accept()
3630
await websocket.send_json({"uid": uid})
3731
await websocket.close()
3832

3933
self.client = TestClient(self.app)
4034

41-
def test_ws_new_no_auth_header_sends_close_1008(self):
35+
def test_no_auth_header_sends_close_1008(self):
4236
"""No auth header -> WebSocketDisconnect with code 1008."""
4337
with self.assertRaises(WebSocketDisconnect) as ctx:
44-
with self.client.websocket_connect("/ws-new"):
38+
with self.client.websocket_connect("/ws-listen"):
4539
self.fail("Expected WebSocket to be closed by server")
4640
self.assertEqual(ctx.exception.code, 1008)
4741

4842
@patch('utils.other.endpoints.verify_token', side_effect=InvalidIdTokenError('Token expired'))
49-
def test_ws_new_invalid_token_sends_close_1008(self, mock_verify):
43+
def test_invalid_token_sends_close_1008(self, mock_verify):
5044
"""Invalid token -> WebSocketDisconnect with code 1008."""
5145
with self.assertRaises(WebSocketDisconnect) as ctx:
52-
with self.client.websocket_connect("/ws-new", headers={"Authorization": "Bearer invalid_token"}):
46+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "Bearer invalid_token"}):
5347
self.fail("Expected WebSocket to be closed by server")
5448
self.assertEqual(ctx.exception.code, 1008)
5549

56-
def test_ws_new_malformed_auth_header_sends_close_1008(self):
50+
def test_malformed_auth_header_sends_close_1008(self):
5751
"""Malformed auth header -> WebSocketDisconnect with code 1008."""
5852
with self.assertRaises(WebSocketDisconnect) as ctx:
59-
with self.client.websocket_connect("/ws-new", headers={"Authorization": "malformed"}):
53+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "malformed"}):
6054
self.fail("Expected WebSocket to be closed by server")
6155
self.assertEqual(ctx.exception.code, 1008)
6256

57+
@patch('utils.other.endpoints.verify_token', return_value='test-uid-123')
58+
def test_valid_token_connects(self, mock_verify):
59+
"""Valid token -> successful connection (no rate limiter involved)."""
60+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "Bearer valid_token"}) as ws:
61+
data = ws.receive_json()
62+
self.assertEqual(data["uid"], "test-uid-123")
63+
mock_verify.assert_called_once_with("valid_token")
64+
65+
def test_empty_bearer_token_sends_close_1008(self):
66+
"""Authorization: 'Bearer ' (empty token) -> close with 1008."""
67+
with self.assertRaises(WebSocketDisconnect) as ctx:
68+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "Bearer "}):
69+
pass
70+
self.assertEqual(ctx.exception.code, 1008)
71+
72+
@patch('utils.other.endpoints.verify_token', side_effect=RuntimeError('unexpected error'))
73+
def test_unexpected_verify_error_sends_close_1008(self, mock_verify):
74+
"""Unexpected error from verify_token -> close with 1008, not handshake crash."""
75+
with self.assertRaises(WebSocketDisconnect) as ctx:
76+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "Bearer token"}):
77+
self.fail("Expected connection to fail")
78+
self.assertEqual(ctx.exception.code, 1008)
79+
80+
@patch('utils.other.endpoints.try_acquire_listen_lock')
81+
@patch('utils.other.endpoints.verify_token', return_value='test-uid-123')
82+
def test_no_rate_limiter_called(self, mock_verify, mock_lock):
83+
"""get_current_user_uid_ws_listen must NOT call the rate limiter."""
84+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "Bearer valid_token"}) as ws:
85+
data = ws.receive_json()
86+
self.assertEqual(data["uid"], "test-uid-123")
87+
mock_lock.assert_not_called()
88+
89+
90+
class TestWebSocketAuthWithRateLimit(unittest.TestCase):
91+
"""Test get_current_user_uid_ws — auth + rate limiting."""
92+
93+
def setUp(self):
94+
self.app = FastAPI()
95+
96+
@self.app.websocket("/ws-ratelimited")
97+
async def ws_ratelimited(websocket: WebSocket, uid: str = Depends(get_current_user_uid_ws)):
98+
await websocket.accept()
99+
await websocket.send_json({"uid": uid})
100+
await websocket.close()
101+
102+
self.client = TestClient(self.app)
103+
63104
@patch('utils.other.endpoints.try_acquire_listen_lock', return_value=True)
64105
@patch('utils.other.endpoints.verify_token', return_value='test-uid-123')
65-
def test_ws_new_valid_token_connects(self, mock_verify, mock_lock):
106+
def test_valid_token_and_lock_connects(self, mock_verify, mock_lock):
66107
"""Valid token + rate limit available -> successful connection."""
67-
with self.client.websocket_connect("/ws-new", headers={"Authorization": "Bearer valid_token"}) as ws:
108+
with self.client.websocket_connect("/ws-ratelimited", headers={"Authorization": "Bearer valid_token"}) as ws:
68109
data = ws.receive_json()
69110
self.assertEqual(data["uid"], "test-uid-123")
70111
mock_verify.assert_called_once_with("valid_token")
71112
mock_lock.assert_called_once_with("test-uid-123")
72113

73114
@patch('utils.other.endpoints.try_acquire_listen_lock', return_value=False)
74115
@patch('utils.other.endpoints.verify_token', return_value='test-uid-456')
75-
def test_ws_new_rate_limited_sends_close_1008(self, mock_verify, mock_lock):
116+
def test_rate_limited_sends_close_1008(self, mock_verify, mock_lock):
76117
"""Valid token but rate limited -> WebSocketDisconnect with code 1008."""
77118
with self.assertRaises(WebSocketDisconnect) as ctx:
78-
with self.client.websocket_connect("/ws-new", headers={"Authorization": "Bearer valid_token"}):
119+
with self.client.websocket_connect(
120+
"/ws-ratelimited", headers={"Authorization": "Bearer valid_token"}
121+
):
79122
self.fail("Expected WebSocket to be closed due to rate limit")
80123
self.assertEqual(ctx.exception.code, 1008)
81124
mock_verify.assert_called_once_with("valid_token")
82125
mock_lock.assert_called_once_with("test-uid-456")
83126

84127
@patch('utils.other.endpoints.try_acquire_listen_lock', side_effect=ConnectionError('redis down'))
85128
@patch('utils.other.endpoints.verify_token', return_value='test-uid-789')
86-
def test_ws_new_redis_failure_fails_open(self, mock_verify, mock_lock):
129+
def test_redis_failure_fails_open(self, mock_verify, mock_lock):
87130
"""Redis failure in rate limiter -> fail-open, connection proceeds."""
88-
with self.client.websocket_connect("/ws-new", headers={"Authorization": "Bearer valid_token"}) as ws:
131+
with self.client.websocket_connect("/ws-ratelimited", headers={"Authorization": "Bearer valid_token"}) as ws:
89132
data = ws.receive_json()
90133
self.assertEqual(data["uid"], "test-uid-789")
91134

92135
@patch('utils.other.endpoints.try_acquire_listen_lock')
93-
def test_ws_new_no_auth_does_not_call_rate_limiter(self, mock_lock):
136+
def test_no_auth_does_not_call_rate_limiter(self, mock_lock):
94137
"""Missing auth header should short-circuit before rate limiter is called."""
95138
with self.assertRaises(WebSocketDisconnect) as ctx:
96-
with self.client.websocket_connect("/ws-new"):
139+
with self.client.websocket_connect("/ws-ratelimited"):
97140
pass
98141
self.assertEqual(ctx.exception.code, 1008)
99142
mock_lock.assert_not_called()
100143

101144
@patch('utils.other.endpoints.try_acquire_listen_lock')
102145
@patch('utils.other.endpoints.verify_token', side_effect=InvalidIdTokenError('expired'))
103-
def test_ws_new_invalid_token_does_not_call_rate_limiter(self, mock_verify, mock_lock):
146+
def test_invalid_token_does_not_call_rate_limiter(self, mock_verify, mock_lock):
104147
"""Invalid token should short-circuit before rate limiter is called."""
105148
with self.assertRaises(WebSocketDisconnect) as ctx:
106-
with self.client.websocket_connect("/ws-new", headers={"Authorization": "Bearer bad"}):
149+
with self.client.websocket_connect("/ws-ratelimited", headers={"Authorization": "Bearer bad"}):
107150
pass
108151
self.assertEqual(ctx.exception.code, 1008)
109152
mock_lock.assert_not_called()
110153

111-
def test_ws_new_empty_bearer_token_sends_close_1008(self):
112-
"""Authorization: 'Bearer ' (empty token) -> close with 1008."""
113-
with self.assertRaises(WebSocketDisconnect) as ctx:
114-
with self.client.websocket_connect("/ws-new", headers={"Authorization": "Bearer "}):
115-
pass
116-
self.assertEqual(ctx.exception.code, 1008)
117-
118-
@patch('utils.other.endpoints.verify_token', side_effect=RuntimeError('unexpected error'))
119-
def test_ws_new_unexpected_verify_error_sends_close_1008(self, mock_verify):
120-
"""Unexpected error from verify_token -> close with 1008, not handshake crash."""
121-
with self.assertRaises(WebSocketDisconnect) as ctx:
122-
with self.client.websocket_connect("/ws-new", headers={"Authorization": "Bearer token"}):
123-
self.fail("Expected connection to fail")
124-
self.assertEqual(ctx.exception.code, 1008)
125-
126154

127155
class TestWebSocketCloseFrameBehavior(unittest.TestCase):
128156
"""Test that WebSocketException actually sends ASGI close message (vs HTTPException which doesn't)."""
@@ -232,7 +260,7 @@ async def send(msg):
232260

233261

234262
class TestListenEndpointNotAffectWebListen(unittest.TestCase):
235-
"""Verify /v4/listen uses WS auth and /v4/web/listen is unchanged (source-level check)."""
263+
"""Verify /v4/listen uses WS auth (no rate limiter) and /v4/web/listen is unchanged (source-level check)."""
236264

237265
def _read_transcribe_source(self):
238266
import os
@@ -241,8 +269,8 @@ def _read_transcribe_source(self):
241269
with open(path) as f:
242270
return f.read()
243271

244-
def test_listen_handler_uses_http_auth_dependency(self):
245-
"""listen_handler should use get_current_user_uid (HTTP variant) — mobile app sends Authorization header."""
272+
def test_listen_handler_uses_ws_listen_auth(self):
273+
"""listen_handler should use get_current_user_uid_ws_listen (WS auth, no rate limiter)."""
246274
source = self._read_transcribe_source()
247275
import re
248276

@@ -253,9 +281,8 @@ def test_listen_handler_uses_http_auth_dependency(self):
253281
)
254282
self.assertIsNotNone(listen_match, "Could not find /v4/listen handler")
255283
handler_sig = listen_match.group()
256-
self.assertIn('get_current_user_uid)', handler_sig, "/v4/listen must use get_current_user_uid")
257-
self.assertNotIn(
258-
'get_current_user_uid_ws', handler_sig, "/v4/listen must NOT use get_current_user_uid_ws"
284+
self.assertIn(
285+
'get_current_user_uid_ws_listen', handler_sig, "/v4/listen must use get_current_user_uid_ws_listen"
259286
)
260287

261288
def test_web_listen_has_no_uid_dependency(self):

backend/utils/other/endpoints.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ def get_current_user_uid(authorization: str = Header(None)):
6161
raise HTTPException(status_code=401, detail="Invalid authorization token")
6262

6363

64-
def get_current_user_uid_ws(authorization: str = Header(None)):
65-
"""FastAPI dependency for WebSocket endpoints with Authorization header.
64+
def _verify_ws_auth(authorization: str) -> str:
65+
"""Common WebSocket auth — verifies token, returns uid.
6666
67-
Unlike get_current_user_uid, raises WebSocketException(code=1008) instead of
68-
HTTPException(401). This ensures the ASGI server sends a proper WebSocket close
69-
frame instead of exiting without a handshake (which causes LB 5xx).
67+
Raises WebSocketException(code=1008) instead of HTTPException(401) so the
68+
ASGI server sends a proper WebSocket close frame (not a handshake crash).
7069
"""
7170
if not authorization:
7271
raise WebSocketException(code=1008, reason="Authorization header not found")
@@ -75,15 +74,31 @@ def get_current_user_uid_ws(authorization: str = Header(None)):
7574

7675
try:
7776
token = authorization.split(' ')[1]
78-
uid = verify_token(token)
77+
return verify_token(token)
7978
except InvalidIdTokenError as e:
8079
logger.error(f"WebSocket auth failed: {e}")
8180
raise WebSocketException(code=1008, reason="Invalid or expired token")
8281
except Exception as e:
8382
logger.error(f"WebSocket auth error: {e}")
8483
raise WebSocketException(code=1008, reason="Auth error")
8584

86-
# Per-UID connection rate limiting (7s window) to prevent retry storms
85+
86+
def get_current_user_uid_ws_listen(authorization: str = Header(None)):
87+
"""WebSocket auth for /v4/listen — NO rate limiting.
88+
89+
Mobile apps reconnect legitimately on network switch / backgrounding,
90+
so the per-UID rate limiter must not block them.
91+
"""
92+
return _verify_ws_auth(authorization)
93+
94+
95+
def get_current_user_uid_ws(authorization: str = Header(None)):
96+
"""WebSocket auth WITH per-UID rate limiting (7s window).
97+
98+
Use for WebSocket endpoints that need retry-storm protection.
99+
"""
100+
uid = _verify_ws_auth(authorization)
101+
87102
# Fail-open on Redis errors to avoid reintroducing handshake crashes
88103
try:
89104
if not try_acquire_listen_lock(uid):

0 commit comments

Comments
 (0)