Skip to content

Commit f378dcc

Browse files
committed
Merge PR #5499: fix(backend-listen): WebSocket auth sends close frame instead of crashing handshake
2 parents abcd2d0 + 35c0b5c commit f378dcc

File tree

4 files changed

+382
-3
lines changed

4 files changed

+382
-3
lines changed

backend/routers/transcribe.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from database.redis_db import (
4141
check_credits_invalidation,
4242
get_cached_user_geolocation,
43-
try_acquire_listen_lock,
4443
)
4544
from models.conversation import (
4645
Conversation,
@@ -2693,7 +2692,7 @@ async def _listen(
26932692
@router.websocket("/v4/listen")
26942693
async def listen_handler(
26952694
websocket: WebSocket,
2696-
uid: str = Depends(auth.get_current_user_uid),
2695+
uid: str = Depends(auth.get_current_user_uid_ws_listen),
26972696
language: str = 'en',
26982697
sample_rate: int = 8000,
26992698
codec: str = 'pcm8',

backend/test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@ pytest tests/unit/test_pusher_private_cloud_data_protection.py -v
4242
pytest tests/unit/test_storage_upload_audio_chunk_data_protection.py -v
4343
pytest tests/unit/test_people_conversations_500s.py -v
4444
pytest tests/unit/test_firestore_read_ops_cache.py -v
45+
pytest tests/unit/test_ws_auth_handshake.py -v
Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
"""Tests for WebSocket auth handshake fix (#5447).
2+
3+
Verifies that:
4+
1. get_current_user_uid_ws_listen sends proper close frames on auth failure (no rate limiter)
5+
2. get_current_user_uid_ws adds per-UID rate limiting on top of auth
6+
3. /v4/web/listen is NOT affected (uses accept-first pattern)
7+
"""
8+
9+
import asyncio
10+
import unittest
11+
from unittest.mock import patch, MagicMock
12+
13+
from fastapi import FastAPI, WebSocket, WebSocketException, Depends
14+
from fastapi.testclient import TestClient
15+
from firebase_admin.auth import InvalidIdTokenError
16+
from starlette.websockets import WebSocketDisconnect
17+
18+
from utils.other.endpoints import get_current_user_uid_ws_listen, get_current_user_uid_ws, get_current_user_uid
19+
20+
21+
class TestWebSocketAuthListen(unittest.TestCase):
22+
"""Test get_current_user_uid_ws_listen — auth-only, no rate limiter (used by /v4/listen)."""
23+
24+
def setUp(self):
25+
self.app = FastAPI()
26+
27+
@self.app.websocket("/ws-listen")
28+
async def ws_listen(websocket: WebSocket, uid: str = Depends(get_current_user_uid_ws_listen)):
29+
await websocket.accept()
30+
await websocket.send_json({"uid": uid})
31+
await websocket.close()
32+
33+
self.client = TestClient(self.app)
34+
35+
def test_no_auth_header_sends_close_1008(self):
36+
"""No auth header -> WebSocketDisconnect with code 1008."""
37+
with self.assertRaises(WebSocketDisconnect) as ctx:
38+
with self.client.websocket_connect("/ws-listen"):
39+
self.fail("Expected WebSocket to be closed by server")
40+
self.assertEqual(ctx.exception.code, 1008)
41+
42+
@patch('utils.other.endpoints.verify_token', side_effect=InvalidIdTokenError('Token expired'))
43+
def test_invalid_token_sends_close_1008(self, mock_verify):
44+
"""Invalid token -> WebSocketDisconnect with code 1008."""
45+
with self.assertRaises(WebSocketDisconnect) as ctx:
46+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "Bearer invalid_token"}):
47+
self.fail("Expected WebSocket to be closed by server")
48+
self.assertEqual(ctx.exception.code, 1008)
49+
50+
def test_malformed_auth_header_sends_close_1008(self):
51+
"""Malformed auth header -> WebSocketDisconnect with code 1008."""
52+
with self.assertRaises(WebSocketDisconnect) as ctx:
53+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "malformed"}):
54+
self.fail("Expected WebSocket to be closed by server")
55+
self.assertEqual(ctx.exception.code, 1008)
56+
57+
@patch('utils.other.endpoints.verify_token', return_value='test-uid-123')
58+
def test_valid_token_connects(self, mock_verify):
59+
"""Valid token -> successful connection (no rate limiter involved)."""
60+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "Bearer valid_token"}) as ws:
61+
data = ws.receive_json()
62+
self.assertEqual(data["uid"], "test-uid-123")
63+
mock_verify.assert_called_once_with("valid_token")
64+
65+
def test_empty_bearer_token_sends_close_1008(self):
66+
"""Authorization: 'Bearer ' (empty token) -> close with 1008."""
67+
with self.assertRaises(WebSocketDisconnect) as ctx:
68+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "Bearer "}):
69+
pass
70+
self.assertEqual(ctx.exception.code, 1008)
71+
72+
@patch('utils.other.endpoints.verify_token', side_effect=RuntimeError('unexpected error'))
73+
def test_unexpected_verify_error_sends_close_1008(self, mock_verify):
74+
"""Unexpected error from verify_token -> close with 1008, not handshake crash."""
75+
with self.assertRaises(WebSocketDisconnect) as ctx:
76+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "Bearer token"}):
77+
self.fail("Expected connection to fail")
78+
self.assertEqual(ctx.exception.code, 1008)
79+
80+
@patch('utils.other.endpoints.try_acquire_listen_lock')
81+
@patch('utils.other.endpoints.verify_token', return_value='test-uid-123')
82+
def test_no_rate_limiter_called(self, mock_verify, mock_lock):
83+
"""get_current_user_uid_ws_listen must NOT call the rate limiter."""
84+
with self.client.websocket_connect("/ws-listen", headers={"Authorization": "Bearer valid_token"}) as ws:
85+
data = ws.receive_json()
86+
self.assertEqual(data["uid"], "test-uid-123")
87+
mock_lock.assert_not_called()
88+
89+
90+
class TestWebSocketAuthWithRateLimit(unittest.TestCase):
91+
"""Test get_current_user_uid_ws — auth + rate limiting."""
92+
93+
def setUp(self):
94+
self.app = FastAPI()
95+
96+
@self.app.websocket("/ws-ratelimited")
97+
async def ws_ratelimited(websocket: WebSocket, uid: str = Depends(get_current_user_uid_ws)):
98+
await websocket.accept()
99+
await websocket.send_json({"uid": uid})
100+
await websocket.close()
101+
102+
self.client = TestClient(self.app)
103+
104+
@patch('utils.other.endpoints.try_acquire_listen_lock', return_value=True)
105+
@patch('utils.other.endpoints.verify_token', return_value='test-uid-123')
106+
def test_valid_token_and_lock_connects(self, mock_verify, mock_lock):
107+
"""Valid token + rate limit available -> successful connection."""
108+
with self.client.websocket_connect("/ws-ratelimited", headers={"Authorization": "Bearer valid_token"}) as ws:
109+
data = ws.receive_json()
110+
self.assertEqual(data["uid"], "test-uid-123")
111+
mock_verify.assert_called_once_with("valid_token")
112+
mock_lock.assert_called_once_with("test-uid-123")
113+
114+
@patch('utils.other.endpoints.try_acquire_listen_lock', return_value=False)
115+
@patch('utils.other.endpoints.verify_token', return_value='test-uid-456')
116+
def test_rate_limited_sends_close_1008(self, mock_verify, mock_lock):
117+
"""Valid token but rate limited -> WebSocketDisconnect with code 1008."""
118+
with self.assertRaises(WebSocketDisconnect) as ctx:
119+
with self.client.websocket_connect(
120+
"/ws-ratelimited", headers={"Authorization": "Bearer valid_token"}
121+
):
122+
self.fail("Expected WebSocket to be closed due to rate limit")
123+
self.assertEqual(ctx.exception.code, 1008)
124+
mock_verify.assert_called_once_with("valid_token")
125+
mock_lock.assert_called_once_with("test-uid-456")
126+
127+
@patch('utils.other.endpoints.try_acquire_listen_lock', side_effect=ConnectionError('redis down'))
128+
@patch('utils.other.endpoints.verify_token', return_value='test-uid-789')
129+
def test_redis_failure_fails_open(self, mock_verify, mock_lock):
130+
"""Redis failure in rate limiter -> fail-open, connection proceeds."""
131+
with self.client.websocket_connect("/ws-ratelimited", headers={"Authorization": "Bearer valid_token"}) as ws:
132+
data = ws.receive_json()
133+
self.assertEqual(data["uid"], "test-uid-789")
134+
135+
def test_malformed_auth_header_sends_close_1008(self):
136+
"""Malformed auth header -> WebSocketDisconnect with code 1008 (via shared _verify_ws_auth)."""
137+
with self.assertRaises(WebSocketDisconnect) as ctx:
138+
with self.client.websocket_connect("/ws-ratelimited", headers={"Authorization": "malformed"}):
139+
self.fail("Expected WebSocket to be closed by server")
140+
self.assertEqual(ctx.exception.code, 1008)
141+
142+
@patch('utils.other.endpoints.try_acquire_listen_lock', side_effect=WebSocketException(code=1008, reason='lock ws exc'))
143+
@patch('utils.other.endpoints.verify_token', return_value='test-uid-reraise')
144+
def test_ws_exception_from_lock_is_reraised(self, mock_verify, mock_lock):
145+
"""WebSocketException from rate limiter is re-raised, not swallowed by fail-open handler."""
146+
with self.assertRaises(WebSocketDisconnect) as ctx:
147+
with self.client.websocket_connect(
148+
"/ws-ratelimited", headers={"Authorization": "Bearer valid_token"}
149+
):
150+
self.fail("Expected WebSocket to be closed")
151+
self.assertEqual(ctx.exception.code, 1008)
152+
153+
@patch('utils.other.endpoints.try_acquire_listen_lock')
154+
def test_no_auth_does_not_call_rate_limiter(self, mock_lock):
155+
"""Missing auth header should short-circuit before rate limiter is called."""
156+
with self.assertRaises(WebSocketDisconnect) as ctx:
157+
with self.client.websocket_connect("/ws-ratelimited"):
158+
pass
159+
self.assertEqual(ctx.exception.code, 1008)
160+
mock_lock.assert_not_called()
161+
162+
@patch('utils.other.endpoints.try_acquire_listen_lock')
163+
@patch('utils.other.endpoints.verify_token', side_effect=InvalidIdTokenError('expired'))
164+
def test_invalid_token_does_not_call_rate_limiter(self, mock_verify, mock_lock):
165+
"""Invalid token should short-circuit before rate limiter is called."""
166+
with self.assertRaises(WebSocketDisconnect) as ctx:
167+
with self.client.websocket_connect("/ws-ratelimited", headers={"Authorization": "Bearer bad"}):
168+
pass
169+
self.assertEqual(ctx.exception.code, 1008)
170+
mock_lock.assert_not_called()
171+
172+
173+
class TestWebSocketCloseFrameBehavior(unittest.TestCase):
174+
"""Test that WebSocketException actually sends ASGI close message (vs HTTPException which doesn't)."""
175+
176+
def test_ws_exception_sends_close_message(self):
177+
"""Verify WebSocketException sends websocket.close ASGI message."""
178+
from fastapi import WebSocketException
179+
180+
app = FastAPI()
181+
182+
def dep_ws():
183+
raise WebSocketException(code=1008, reason="test rejection")
184+
185+
@app.websocket("/test")
186+
async def handler(ws: WebSocket, _: str = Depends(dep_ws)):
187+
await ws.accept()
188+
189+
sent_messages = []
190+
191+
async def run():
192+
scope = {
193+
'type': 'websocket',
194+
'asgi': {'version': '3.0', 'spec_version': '2.3'},
195+
'http_version': '1.1',
196+
'scheme': 'ws',
197+
'method': 'GET',
198+
'path': '/test',
199+
'raw_path': b'/test',
200+
'query_string': b'',
201+
'root_path': '',
202+
'headers': [],
203+
'client': ('127.0.0.1', 12345),
204+
'server': ('testserver', 80),
205+
'subprotocols': [],
206+
'state': {},
207+
}
208+
recv_events = [{'type': 'websocket.connect'}]
209+
210+
async def receive():
211+
if recv_events:
212+
return recv_events.pop(0)
213+
await asyncio.sleep(3600)
214+
215+
async def send(msg):
216+
sent_messages.append(msg)
217+
218+
await app(scope, receive, send)
219+
220+
asyncio.run(run())
221+
222+
# WebSocketException should produce a websocket.close message
223+
close_messages = [m for m in sent_messages if m.get('type') == 'websocket.close']
224+
self.assertEqual(
225+
len(close_messages), 1, f"Expected 1 close message, got {len(close_messages)}: {sent_messages}"
226+
)
227+
self.assertEqual(close_messages[0]['code'], 1008)
228+
229+
def test_http_exception_sends_no_close_message(self):
230+
"""Verify HTTPException does NOT send any ASGI message (causes LB 5xx)."""
231+
from fastapi import HTTPException
232+
233+
app = FastAPI()
234+
235+
def dep_http():
236+
raise HTTPException(status_code=401, detail="unauthorized")
237+
238+
@app.websocket("/test")
239+
async def handler(ws: WebSocket, _: str = Depends(dep_http)):
240+
await ws.accept()
241+
242+
sent_messages = []
243+
244+
async def run():
245+
scope = {
246+
'type': 'websocket',
247+
'asgi': {'version': '3.0', 'spec_version': '2.3'},
248+
'http_version': '1.1',
249+
'scheme': 'ws',
250+
'method': 'GET',
251+
'path': '/test',
252+
'raw_path': b'/test',
253+
'query_string': b'',
254+
'root_path': '',
255+
'headers': [],
256+
'client': ('127.0.0.1', 12345),
257+
'server': ('testserver', 80),
258+
'subprotocols': [],
259+
'state': {},
260+
}
261+
recv_events = [{'type': 'websocket.connect'}]
262+
263+
async def receive():
264+
if recv_events:
265+
return recv_events.pop(0)
266+
await asyncio.sleep(3600)
267+
268+
async def send(msg):
269+
sent_messages.append(msg)
270+
271+
await app(scope, receive, send)
272+
273+
asyncio.run(run())
274+
275+
# HTTPException should produce NO websocket.close message — this is the bug
276+
close_messages = [m for m in sent_messages if m.get('type') == 'websocket.close']
277+
self.assertEqual(len(close_messages), 0, f"HTTPException should not send close frame, got: {sent_messages}")
278+
279+
280+
class TestListenEndpointNotAffectWebListen(unittest.TestCase):
281+
"""Verify /v4/listen uses WS auth (no rate limiter) and /v4/web/listen is unchanged (source-level check)."""
282+
283+
def _read_transcribe_source(self):
284+
import os
285+
286+
path = os.path.join(os.path.dirname(__file__), '..', '..', 'routers', 'transcribe.py')
287+
with open(path) as f:
288+
return f.read()
289+
290+
def test_listen_handler_uses_ws_listen_auth(self):
291+
"""listen_handler should use get_current_user_uid_ws_listen (WS auth, no rate limiter)."""
292+
source = self._read_transcribe_source()
293+
import re
294+
295+
listen_match = re.search(
296+
r'@router\.websocket\("/v4/listen"\)\s*\nasync def listen_handler\([^)]+\)',
297+
source,
298+
re.DOTALL,
299+
)
300+
self.assertIsNotNone(listen_match, "Could not find /v4/listen handler")
301+
handler_sig = listen_match.group()
302+
self.assertIn(
303+
'get_current_user_uid_ws_listen', handler_sig, "/v4/listen must use get_current_user_uid_ws_listen"
304+
)
305+
306+
def test_web_listen_has_no_uid_dependency(self):
307+
"""web_listen_handler should NOT have uid Depends — uses first-message auth."""
308+
source = self._read_transcribe_source()
309+
import re
310+
311+
web_match = re.search(
312+
r'@router\.websocket\("/v4/web/listen"\)\s*\nasync def web_listen_handler\([^)]+\)',
313+
source,
314+
re.DOTALL,
315+
)
316+
self.assertIsNotNone(web_match, "Could not find /v4/web/listen handler")
317+
handler_sig = web_match.group()
318+
self.assertNotIn(
319+
'get_current_user_uid',
320+
handler_sig,
321+
"/v4/web/listen must NOT have auth dependency — uses accept-first pattern",
322+
)
323+
324+
325+
if __name__ == '__main__':
326+
unittest.main()

0 commit comments

Comments
 (0)