11"""Tests for WebSocket auth handshake fix (#5447).
22
33Verifies that:
4- 1. WebSocket endpoints send proper close frames on auth failure (not HTTPException )
5- 2. Per -UID rate limiting blocks retry storms
4+ 1. get_current_user_uid_ws_listen sends proper close frames on auth failure (no rate limiter )
5+ 2. get_current_user_uid_ws adds per -UID rate limiting on top of auth
663. /v4/web/listen is NOT affected (uses accept-first pattern)
77"""
88
1515from firebase_admin .auth import InvalidIdTokenError
1616from starlette .websockets import WebSocketDisconnect
1717
18- from utils .other .endpoints import get_current_user_uid_ws , get_current_user_uid
18+ from utils .other .endpoints import get_current_user_uid_ws_listen , get_current_user_uid_ws , get_current_user_uid
1919
2020
21- class TestWebSocketAuthDependency (unittest .TestCase ):
22- """Test that get_current_user_uid_ws raises WebSocketException instead of HTTPException ."""
21+ class TestWebSocketAuthListen (unittest .TestCase ):
22+ """Test get_current_user_uid_ws_listen — auth-only, no rate limiter (used by /v4/listen) ."""
2323
2424 def setUp (self ):
2525 self .app = FastAPI ()
2626
27- @self .app .websocket ("/ws-new" )
28- async def ws_new (websocket : WebSocket , uid : str = Depends (get_current_user_uid_ws )):
29- await websocket .accept ()
30- await websocket .send_json ({"uid" : uid })
31- await websocket .close ()
32-
33- @self .app .websocket ("/ws-old" )
34- async def ws_old (websocket : WebSocket , uid : str = Depends (get_current_user_uid )):
27+ @self .app .websocket ("/ws-listen" )
28+ async def ws_listen (websocket : WebSocket , uid : str = Depends (get_current_user_uid_ws_listen )):
3529 await websocket .accept ()
3630 await websocket .send_json ({"uid" : uid })
3731 await websocket .close ()
3832
3933 self .client = TestClient (self .app )
4034
41- def test_ws_new_no_auth_header_sends_close_1008 (self ):
35+ def test_no_auth_header_sends_close_1008 (self ):
4236 """No auth header -> WebSocketDisconnect with code 1008."""
4337 with self .assertRaises (WebSocketDisconnect ) as ctx :
44- with self .client .websocket_connect ("/ws-new " ):
38+ with self .client .websocket_connect ("/ws-listen " ):
4539 self .fail ("Expected WebSocket to be closed by server" )
4640 self .assertEqual (ctx .exception .code , 1008 )
4741
4842 @patch ('utils.other.endpoints.verify_token' , side_effect = InvalidIdTokenError ('Token expired' ))
49- def test_ws_new_invalid_token_sends_close_1008 (self , mock_verify ):
43+ def test_invalid_token_sends_close_1008 (self , mock_verify ):
5044 """Invalid token -> WebSocketDisconnect with code 1008."""
5145 with self .assertRaises (WebSocketDisconnect ) as ctx :
52- with self .client .websocket_connect ("/ws-new " , headers = {"Authorization" : "Bearer invalid_token" }):
46+ with self .client .websocket_connect ("/ws-listen " , headers = {"Authorization" : "Bearer invalid_token" }):
5347 self .fail ("Expected WebSocket to be closed by server" )
5448 self .assertEqual (ctx .exception .code , 1008 )
5549
56- def test_ws_new_malformed_auth_header_sends_close_1008 (self ):
50+ def test_malformed_auth_header_sends_close_1008 (self ):
5751 """Malformed auth header -> WebSocketDisconnect with code 1008."""
5852 with self .assertRaises (WebSocketDisconnect ) as ctx :
59- with self .client .websocket_connect ("/ws-new " , headers = {"Authorization" : "malformed" }):
53+ with self .client .websocket_connect ("/ws-listen " , headers = {"Authorization" : "malformed" }):
6054 self .fail ("Expected WebSocket to be closed by server" )
6155 self .assertEqual (ctx .exception .code , 1008 )
6256
57+ @patch ('utils.other.endpoints.verify_token' , return_value = 'test-uid-123' )
58+ def test_valid_token_connects (self , mock_verify ):
59+ """Valid token -> successful connection (no rate limiter involved)."""
60+ with self .client .websocket_connect ("/ws-listen" , headers = {"Authorization" : "Bearer valid_token" }) as ws :
61+ data = ws .receive_json ()
62+ self .assertEqual (data ["uid" ], "test-uid-123" )
63+ mock_verify .assert_called_once_with ("valid_token" )
64+
65+ def test_empty_bearer_token_sends_close_1008 (self ):
66+ """Authorization: 'Bearer ' (empty token) -> close with 1008."""
67+ with self .assertRaises (WebSocketDisconnect ) as ctx :
68+ with self .client .websocket_connect ("/ws-listen" , headers = {"Authorization" : "Bearer " }):
69+ pass
70+ self .assertEqual (ctx .exception .code , 1008 )
71+
72+ @patch ('utils.other.endpoints.verify_token' , side_effect = RuntimeError ('unexpected error' ))
73+ def test_unexpected_verify_error_sends_close_1008 (self , mock_verify ):
74+ """Unexpected error from verify_token -> close with 1008, not handshake crash."""
75+ with self .assertRaises (WebSocketDisconnect ) as ctx :
76+ with self .client .websocket_connect ("/ws-listen" , headers = {"Authorization" : "Bearer token" }):
77+ self .fail ("Expected connection to fail" )
78+ self .assertEqual (ctx .exception .code , 1008 )
79+
80+ @patch ('utils.other.endpoints.try_acquire_listen_lock' )
81+ @patch ('utils.other.endpoints.verify_token' , return_value = 'test-uid-123' )
82+ def test_no_rate_limiter_called (self , mock_verify , mock_lock ):
83+ """get_current_user_uid_ws_listen must NOT call the rate limiter."""
84+ with self .client .websocket_connect ("/ws-listen" , headers = {"Authorization" : "Bearer valid_token" }) as ws :
85+ data = ws .receive_json ()
86+ self .assertEqual (data ["uid" ], "test-uid-123" )
87+ mock_lock .assert_not_called ()
88+
89+
90+ class TestWebSocketAuthWithRateLimit (unittest .TestCase ):
91+ """Test get_current_user_uid_ws — auth + rate limiting."""
92+
93+ def setUp (self ):
94+ self .app = FastAPI ()
95+
96+ @self .app .websocket ("/ws-ratelimited" )
97+ async def ws_ratelimited (websocket : WebSocket , uid : str = Depends (get_current_user_uid_ws )):
98+ await websocket .accept ()
99+ await websocket .send_json ({"uid" : uid })
100+ await websocket .close ()
101+
102+ self .client = TestClient (self .app )
103+
63104 @patch ('utils.other.endpoints.try_acquire_listen_lock' , return_value = True )
64105 @patch ('utils.other.endpoints.verify_token' , return_value = 'test-uid-123' )
65- def test_ws_new_valid_token_connects (self , mock_verify , mock_lock ):
106+ def test_valid_token_and_lock_connects (self , mock_verify , mock_lock ):
66107 """Valid token + rate limit available -> successful connection."""
67- with self .client .websocket_connect ("/ws-new " , headers = {"Authorization" : "Bearer valid_token" }) as ws :
108+ with self .client .websocket_connect ("/ws-ratelimited " , headers = {"Authorization" : "Bearer valid_token" }) as ws :
68109 data = ws .receive_json ()
69110 self .assertEqual (data ["uid" ], "test-uid-123" )
70111 mock_verify .assert_called_once_with ("valid_token" )
71112 mock_lock .assert_called_once_with ("test-uid-123" )
72113
73114 @patch ('utils.other.endpoints.try_acquire_listen_lock' , return_value = False )
74115 @patch ('utils.other.endpoints.verify_token' , return_value = 'test-uid-456' )
75- def test_ws_new_rate_limited_sends_close_1008 (self , mock_verify , mock_lock ):
116+ def test_rate_limited_sends_close_1008 (self , mock_verify , mock_lock ):
76117 """Valid token but rate limited -> WebSocketDisconnect with code 1008."""
77118 with self .assertRaises (WebSocketDisconnect ) as ctx :
78- with self .client .websocket_connect ("/ws-new" , headers = {"Authorization" : "Bearer valid_token" }):
119+ with self .client .websocket_connect (
120+ "/ws-ratelimited" , headers = {"Authorization" : "Bearer valid_token" }
121+ ):
79122 self .fail ("Expected WebSocket to be closed due to rate limit" )
80123 self .assertEqual (ctx .exception .code , 1008 )
81124 mock_verify .assert_called_once_with ("valid_token" )
82125 mock_lock .assert_called_once_with ("test-uid-456" )
83126
84127 @patch ('utils.other.endpoints.try_acquire_listen_lock' , side_effect = ConnectionError ('redis down' ))
85128 @patch ('utils.other.endpoints.verify_token' , return_value = 'test-uid-789' )
86- def test_ws_new_redis_failure_fails_open (self , mock_verify , mock_lock ):
129+ def test_redis_failure_fails_open (self , mock_verify , mock_lock ):
87130 """Redis failure in rate limiter -> fail-open, connection proceeds."""
88- with self .client .websocket_connect ("/ws-new " , headers = {"Authorization" : "Bearer valid_token" }) as ws :
131+ with self .client .websocket_connect ("/ws-ratelimited " , headers = {"Authorization" : "Bearer valid_token" }) as ws :
89132 data = ws .receive_json ()
90133 self .assertEqual (data ["uid" ], "test-uid-789" )
91134
92135 @patch ('utils.other.endpoints.try_acquire_listen_lock' )
93- def test_ws_new_no_auth_does_not_call_rate_limiter (self , mock_lock ):
136+ def test_no_auth_does_not_call_rate_limiter (self , mock_lock ):
94137 """Missing auth header should short-circuit before rate limiter is called."""
95138 with self .assertRaises (WebSocketDisconnect ) as ctx :
96- with self .client .websocket_connect ("/ws-new " ):
139+ with self .client .websocket_connect ("/ws-ratelimited " ):
97140 pass
98141 self .assertEqual (ctx .exception .code , 1008 )
99142 mock_lock .assert_not_called ()
100143
101144 @patch ('utils.other.endpoints.try_acquire_listen_lock' )
102145 @patch ('utils.other.endpoints.verify_token' , side_effect = InvalidIdTokenError ('expired' ))
103- def test_ws_new_invalid_token_does_not_call_rate_limiter (self , mock_verify , mock_lock ):
146+ def test_invalid_token_does_not_call_rate_limiter (self , mock_verify , mock_lock ):
104147 """Invalid token should short-circuit before rate limiter is called."""
105148 with self .assertRaises (WebSocketDisconnect ) as ctx :
106- with self .client .websocket_connect ("/ws-new " , headers = {"Authorization" : "Bearer bad" }):
149+ with self .client .websocket_connect ("/ws-ratelimited " , headers = {"Authorization" : "Bearer bad" }):
107150 pass
108151 self .assertEqual (ctx .exception .code , 1008 )
109152 mock_lock .assert_not_called ()
110153
111- def test_ws_new_empty_bearer_token_sends_close_1008 (self ):
112- """Authorization: 'Bearer ' (empty token) -> close with 1008."""
113- with self .assertRaises (WebSocketDisconnect ) as ctx :
114- with self .client .websocket_connect ("/ws-new" , headers = {"Authorization" : "Bearer " }):
115- pass
116- self .assertEqual (ctx .exception .code , 1008 )
117-
118- @patch ('utils.other.endpoints.verify_token' , side_effect = RuntimeError ('unexpected error' ))
119- def test_ws_new_unexpected_verify_error_sends_close_1008 (self , mock_verify ):
120- """Unexpected error from verify_token -> close with 1008, not handshake crash."""
121- with self .assertRaises (WebSocketDisconnect ) as ctx :
122- with self .client .websocket_connect ("/ws-new" , headers = {"Authorization" : "Bearer token" }):
123- self .fail ("Expected connection to fail" )
124- self .assertEqual (ctx .exception .code , 1008 )
125-
126154
127155class TestWebSocketCloseFrameBehavior (unittest .TestCase ):
128156 """Test that WebSocketException actually sends ASGI close message (vs HTTPException which doesn't)."""
@@ -232,7 +260,7 @@ async def send(msg):
232260
233261
234262class TestListenEndpointNotAffectWebListen (unittest .TestCase ):
235- """Verify /v4/listen uses WS auth and /v4/web/listen is unchanged (source-level check)."""
263+ """Verify /v4/listen uses WS auth (no rate limiter) and /v4/web/listen is unchanged (source-level check)."""
236264
237265 def _read_transcribe_source (self ):
238266 import os
@@ -241,8 +269,8 @@ def _read_transcribe_source(self):
241269 with open (path ) as f :
242270 return f .read ()
243271
244- def test_listen_handler_uses_http_auth_dependency (self ):
245- """listen_handler should use get_current_user_uid (HTTP variant) — mobile app sends Authorization header ."""
272+ def test_listen_handler_uses_ws_listen_auth (self ):
273+ """listen_handler should use get_current_user_uid_ws_listen (WS auth, no rate limiter) ."""
246274 source = self ._read_transcribe_source ()
247275 import re
248276
@@ -253,9 +281,8 @@ def test_listen_handler_uses_http_auth_dependency(self):
253281 )
254282 self .assertIsNotNone (listen_match , "Could not find /v4/listen handler" )
255283 handler_sig = listen_match .group ()
256- self .assertIn ('get_current_user_uid)' , handler_sig , "/v4/listen must use get_current_user_uid" )
257- self .assertNotIn (
258- 'get_current_user_uid_ws' , handler_sig , "/v4/listen must NOT use get_current_user_uid_ws"
284+ self .assertIn (
285+ 'get_current_user_uid_ws_listen' , handler_sig , "/v4/listen must use get_current_user_uid_ws_listen"
259286 )
260287
261288 def test_web_listen_has_no_uid_dependency (self ):
0 commit comments