88import pytest
99from fastapi import FastAPI
1010from httpx import ASGITransport , AsyncClient
11- from starlette .datastructures import Headers
11+ from starlette .datastructures import Headers , MutableHeaders
1212from starlette .responses import JSONResponse
1313
14- from rock .admin .entrypoints .sandbox_proxy_api import sandbox_proxy_router , set_sandbox_proxy_service
14+ from rock .admin .entrypoints .sandbox_proxy_api import (
15+ sandbox_proxy_router ,
16+ set_sandbox_proxy_service ,
17+ vnc_websocket_proxy ,
18+ websocket_proxy ,
19+ )
1520from rock .sandbox .service .sandbox_proxy_service import SandboxProxyService
1621
22+
23+ def _make_mock_websocket (query_string : str = "" , headers : dict | None = None ) -> MagicMock :
24+ """Build a minimal mock WebSocket for testing handler logic."""
25+ ws = MagicMock ()
26+ ws .close = AsyncMock ()
27+ ws .headers = MutableHeaders (
28+ scope = {"type" : "websocket" , "headers" : [(k .lower ().encode (), v .encode ()) for k , v in (headers or {}).items ()]}
29+ )
30+ return ws
31+
32+
1733# ─────────────────────────────────────────────────────────────────────────────
1834# Fixtures
1935# ─────────────────────────────────────────────────────────────────────────────
@@ -168,10 +184,9 @@ class TestWebsocketProxyPortParam:
168184
169185 async def test_websocket_proxy_passes_port_to_service (self , app ):
170186 """When rock_target_port=8888 is given, service.websocket_proxy should receive port=8888."""
171- a , svc = app
172- client = TestClientWS (a )
173- with client .websocket_connect ("/sandboxes/sb1/proxy/ws?rock_target_port=8888" ):
174- pass
187+ _a , svc = app
188+ ws = _make_mock_websocket ()
189+ await websocket_proxy (ws , id = "sb1" , path = "ws" , rock_target_port = 8888 )
175190
176191 svc .websocket_proxy .assert_called_once ()
177192 call = svc .websocket_proxy .call_args
@@ -180,10 +195,9 @@ async def test_websocket_proxy_passes_port_to_service(self, app):
180195
181196 async def test_websocket_proxy_defaults_to_none_when_no_port (self , app ):
182197 """When rock_target_port is not specified, service.websocket_proxy should receive port=None."""
183- a , svc = app
184- client = TestClientWS (a )
185- with client .websocket_connect ("/sandboxes/sb1/proxy/ws" ):
186- pass
198+ _a , svc = app
199+ ws = _make_mock_websocket ()
200+ await websocket_proxy (ws , id = "sb1" , path = "ws" , rock_target_port = None )
187201
188202 svc .websocket_proxy .assert_called_once ()
189203 call = svc .websocket_proxy .call_args
@@ -192,27 +206,19 @@ async def test_websocket_proxy_defaults_to_none_when_no_port(self, app):
192206
193207 async def test_websocket_proxy_rejects_invalid_port (self , app ):
194208 """When rock_target_port < 1024, websocket connection should close with code 1008."""
195- a , svc = app
196- client = TestClientWS (a )
197- # Port 80 is below 1024 — expect rejection without calling service
198- try :
199- with client .websocket_connect ("/sandboxes/sb1/proxy/ws?rock_target_port=80" ):
200- pass
201- except Exception :
202- pass # Expect disconnect
209+ _a , svc = app
210+ ws = _make_mock_websocket ()
211+ await websocket_proxy (ws , id = "sb1" , path = "ws" , rock_target_port = 80 )
203212
204- # Service should NOT be called for invalid port
205213 svc .websocket_proxy .assert_not_called ()
214+ ws .close .assert_called_once ()
215+ assert ws .close .call_args .kwargs .get ("code" ) == 1008 or ws .close .call_args .args [0 ] == 1008
206216
207217 async def test_websocket_proxy_port_from_header (self , app ):
208218 """When X-ROCK-Target-Port header is given, service.websocket_proxy should receive the port."""
209- a , svc = app
210- client = TestClientWS (a )
211- try :
212- with client .websocket_connect ("/sandboxes/sb1/proxy/ws" , headers = {"X-ROCK-Target-Port" : "8888" }):
213- pass
214- except Exception :
215- pass
219+ _a , svc = app
220+ ws = _make_mock_websocket (headers = {"X-ROCK-Target-Port" : "8888" })
221+ await websocket_proxy (ws , id = "sb1" , path = "ws" , rock_target_port = None )
216222
217223 svc .websocket_proxy .assert_called_once ()
218224 call = svc .websocket_proxy .call_args
@@ -221,17 +227,12 @@ async def test_websocket_proxy_port_from_header(self, app):
221227
222228 async def test_websocket_proxy_port_conflict (self , app ):
223229 """When both header and query param are given, should close with error."""
224- a , svc = app
225- client = TestClientWS (a )
226- try :
227- with client .websocket_connect (
228- "/sandboxes/sb1/proxy/ws?rock_target_port=8000" , headers = {"X-ROCK-Target-Port" : "9000" }
229- ):
230- pass
231- except Exception :
232- pass
230+ _a , svc = app
231+ ws = _make_mock_websocket (headers = {"X-ROCK-Target-Port" : "9000" })
232+ await websocket_proxy (ws , id = "sb1" , path = "ws" , rock_target_port = 8000 )
233233
234234 svc .websocket_proxy .assert_not_called ()
235+ ws .close .assert_called_once ()
235236
236237
237238# ─────────────────────────────────────────────────────────────────────────────
@@ -893,13 +894,9 @@ class TestPathBasedPortWsRouting:
893894
894895 async def test_ws_port_in_path_is_extracted (self , app ):
895896 """WS /proxy/port/8006/ws should forward path='ws' with port=8006."""
896- a , svc = app
897- client = TestClientWS (a )
898- try :
899- with client .websocket_connect ("/sandboxes/sb1/proxy/port/8006/ws" ):
900- pass
901- except Exception :
902- pass
897+ _a , svc = app
898+ ws = _make_mock_websocket ()
899+ await websocket_proxy (ws , id = "sb1" , path = "port/8006/ws" , rock_target_port = None )
903900
904901 svc .websocket_proxy .assert_called_once ()
905902 call = svc .websocket_proxy .call_args
@@ -910,15 +907,12 @@ async def test_ws_port_in_path_is_extracted(self, app):
910907
911908 async def test_ws_port_in_path_with_query_param_conflict (self , app ):
912909 """WS /proxy/port/8006/ws?rock_target_port=9000 should close with error (conflict)."""
913- a , svc = app
914- client = TestClientWS (a )
915- try :
916- with client .websocket_connect ("/sandboxes/sb1/proxy/port/8006/ws?rock_target_port=9000" ):
917- pass
918- except Exception :
919- pass
910+ _a , svc = app
911+ ws = _make_mock_websocket ()
912+ await websocket_proxy (ws , id = "sb1" , path = "port/8006/ws" , rock_target_port = 9000 )
920913
921914 svc .websocket_proxy .assert_not_called ()
915+ ws .close .assert_called_once ()
922916
923917
924918# ─────────────────────────────────────────────────────────────────────────────
@@ -999,13 +993,9 @@ class TestVncWebSocketProxy:
999993
1000994 async def test_vnc_ws_route_forwards_to_port_8006 (self , app ):
1001995 """WS /vnc/ws should forward to port 8006."""
1002- a , svc = app
1003- client = TestClientWS (a )
1004- try :
1005- with client .websocket_connect ("/sandboxes/sb1/vnc/ws" ):
1006- pass
1007- except Exception :
1008- pass
996+ _a , svc = app
997+ ws = _make_mock_websocket ()
998+ await vnc_websocket_proxy (ws , sandbox_id = "sb1" , path = "ws" )
1009999
10101000 svc .websocket_proxy .assert_called_once ()
10111001 call = svc .websocket_proxy .call_args
@@ -1014,13 +1004,9 @@ async def test_vnc_ws_route_forwards_to_port_8006(self, app):
10141004
10151005 async def test_vnc_ws_route_preserves_path (self , app ):
10161006 """WS /vnc/websockify should forward path='websockify'."""
1017- a , svc = app
1018- client = TestClientWS (a )
1019- try :
1020- with client .websocket_connect ("/sandboxes/sb1/vnc/websockify" ):
1021- pass
1022- except Exception :
1023- pass
1007+ _a , svc = app
1008+ ws = _make_mock_websocket ()
1009+ await vnc_websocket_proxy (ws , sandbox_id = "sb1" , path = "websockify" )
10241010
10251011 svc .websocket_proxy .assert_called_once ()
10261012 call = svc .websocket_proxy .call_args
@@ -1029,34 +1015,11 @@ async def test_vnc_ws_route_preserves_path(self, app):
10291015
10301016 async def test_vnc_ws_route_ignores_query_param_port (self , app ):
10311017 """VNC WS proxy should ignore rock_target_port and always use 8006."""
1032- a , svc = app
1033- client = TestClientWS (a )
1034- try :
1035- with client .websocket_connect ("/sandboxes/sb1/vnc/ws?rock_target_port=9000" ):
1036- pass
1037- except Exception :
1038- pass
1018+ _a , svc = app
1019+ ws = _make_mock_websocket ()
1020+ await vnc_websocket_proxy (ws , sandbox_id = "sb1" , path = "ws" )
10391021
10401022 svc .websocket_proxy .assert_called_once ()
10411023 call = svc .websocket_proxy .call_args
10421024 port = call .kwargs .get ("port" ) or (call .args [3 ] if len (call .args ) > 3 else None )
10431025 assert port == 8006
1044-
1045-
1046- # ─────────────────────────────────────────────────────────────────────────────
1047- # Helper — sync WebSocket test client wrapper
1048- # ─────────────────────────────────────────────────────────────────────────────
1049-
1050-
1051- class TestClientWS :
1052- """Thin wrapper around FastAPI TestClient for WebSocket connections."""
1053-
1054- def __init__ (self , app ):
1055- from fastapi .testclient import TestClient
1056-
1057- self ._client = TestClient (app , raise_server_exceptions = False )
1058-
1059- def websocket_connect (self , path , headers = None ):
1060- if headers is None :
1061- headers = {}
1062- return self ._client .websocket_connect (path , headers = headers )
0 commit comments