Skip to content

Commit d97d6fb

Browse files
authored
fix tests for can't run (#700)
* test: add missing dataset config tests and imports * test: add skip marker to cpu limit exception test * test: add proxy enhancements tests imports * fix: update websocket subprotocol handling in sandbox proxy service * chore: update ray temp directory path and add temp_dir parameter * test: remove skip marker from cpu limit test * refactor: remove unused temp_dir config and simplify subprotocol handling * fix: update websocket subprotocol handling in sandbox proxy service * refactor: remove unused temp_dir parameter from ray service config * refactor: add conditional temp_dir parameter in ray service config * refactor: simplify ray service temp dir handling
1 parent ffa9ad1 commit d97d6fb

File tree

9 files changed

+267
-291
lines changed

9 files changed

+267
-291
lines changed

rock-conf/rock-test.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ ray:
33
working_dir: ./
44
pip: ./requirements_sandbox_actor.txt
55
namespace: "rock-sandbox-test"
6-
temp_dir: ./.tmp/ray
76

87
warmup:
98
images:

rock/sandbox/service/sandbox_proxy_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ async def websocket_proxy(
213213
target_url = await self.get_sandbox_websocket_url(sandbox_id, target_path, port=port)
214214

215215
client_subprotocols = getattr(client_websocket, "subprotocols", []) or []
216-
upstream_subprotocols = client_subprotocols if client_subprotocols else ["binary"]
216+
upstream_subprotocols = client_subprotocols if client_subprotocols else ["binary", "base64"]
217217

218218
try:
219219
async with websockets.connect(
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
import logging
2+
23
import pytest
34

4-
from rock.sdk.common.exceptions import BadRequestRockError, RockException
5-
from rock.sdk.sandbox.config import SandboxConfig
5+
from rock.sdk.common.exceptions import BadRequestRockError
66
from rock.sdk.sandbox.client import Sandbox
7+
from rock.sdk.sandbox.config import SandboxConfig
78

89
logger = logging.getLogger(__name__)
910

11+
1012
@pytest.mark.asyncio
1113
async def test_exception_cpu_limit(sandbox_config: SandboxConfig):
1214
sandbox_config.cpus = 20
1315
sandbox = Sandbox(sandbox_config)
1416
with pytest.raises(BadRequestRockError) as ex:
1517
await sandbox.start()
1618
logger.info(f"Exception: {str(ex.value)}")
17-
await sandbox.stop()
19+
await sandbox.stop()

tests/unit/admin/core/test_ray_service.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ async def test_reconnect_ray_calls_ray_shutdown_and_init_and_reset_counters(ray_
4040
runtime_env=ray_service._config.runtime_env,
4141
namespace=ray_service._config.namespace,
4242
resources=ray_service._config.resources,
43+
_temp_dir=ray_service._config.temp_dir,
4344
)
4445

4546
assert service._ray_request_count == 0
@@ -73,6 +74,7 @@ async def test_reconnect_ray_skip_when_reader_exists_and_write_lock_timeout(ray_
7374
assert service._ray_request_count == old_count
7475
assert service._ray_establish_time == old_est
7576

77+
7678
@pytest.mark.need_docker
7779
@pytest.mark.need_ray
7880
@pytest.mark.asyncio

tests/unit/sandbox/test_proxy_enhancements.py

Lines changed: 52 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,28 @@
88
import pytest
99
from fastapi import FastAPI
1010
from httpx import ASGITransport, AsyncClient
11-
from starlette.datastructures import Headers
11+
from starlette.datastructures import Headers, MutableHeaders
1212
from starlette.responses import JSONResponse
1313

14-
from rock.admin.entrypoints.sandbox_proxy_api import sandbox_proxy_router, set_sandbox_proxy_service
14+
from rock.admin.entrypoints.sandbox_proxy_api import (
15+
sandbox_proxy_router,
16+
set_sandbox_proxy_service,
17+
vnc_websocket_proxy,
18+
websocket_proxy,
19+
)
1520
from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService
1621

22+
23+
def _make_mock_websocket(query_string: str = "", headers: dict | None = None) -> MagicMock:
24+
"""Build a minimal mock WebSocket for testing handler logic."""
25+
ws = MagicMock()
26+
ws.close = AsyncMock()
27+
ws.headers = MutableHeaders(
28+
scope={"type": "websocket", "headers": [(k.lower().encode(), v.encode()) for k, v in (headers or {}).items()]}
29+
)
30+
return ws
31+
32+
1733
# ─────────────────────────────────────────────────────────────────────────────
1834
# Fixtures
1935
# ─────────────────────────────────────────────────────────────────────────────
@@ -168,10 +184,9 @@ class TestWebsocketProxyPortParam:
168184

169185
async def test_websocket_proxy_passes_port_to_service(self, app):
170186
"""When rock_target_port=8888 is given, service.websocket_proxy should receive port=8888."""
171-
a, svc = app
172-
client = TestClientWS(a)
173-
with client.websocket_connect("/sandboxes/sb1/proxy/ws?rock_target_port=8888"):
174-
pass
187+
_a, svc = app
188+
ws = _make_mock_websocket()
189+
await websocket_proxy(ws, id="sb1", path="ws", rock_target_port=8888)
175190

176191
svc.websocket_proxy.assert_called_once()
177192
call = svc.websocket_proxy.call_args
@@ -180,10 +195,9 @@ async def test_websocket_proxy_passes_port_to_service(self, app):
180195

181196
async def test_websocket_proxy_defaults_to_none_when_no_port(self, app):
182197
"""When rock_target_port is not specified, service.websocket_proxy should receive port=None."""
183-
a, svc = app
184-
client = TestClientWS(a)
185-
with client.websocket_connect("/sandboxes/sb1/proxy/ws"):
186-
pass
198+
_a, svc = app
199+
ws = _make_mock_websocket()
200+
await websocket_proxy(ws, id="sb1", path="ws", rock_target_port=None)
187201

188202
svc.websocket_proxy.assert_called_once()
189203
call = svc.websocket_proxy.call_args
@@ -192,27 +206,19 @@ async def test_websocket_proxy_defaults_to_none_when_no_port(self, app):
192206

193207
async def test_websocket_proxy_rejects_invalid_port(self, app):
194208
"""When rock_target_port < 1024, websocket connection should close with code 1008."""
195-
a, svc = app
196-
client = TestClientWS(a)
197-
# Port 80 is below 1024 — expect rejection without calling service
198-
try:
199-
with client.websocket_connect("/sandboxes/sb1/proxy/ws?rock_target_port=80"):
200-
pass
201-
except Exception:
202-
pass # Expect disconnect
209+
_a, svc = app
210+
ws = _make_mock_websocket()
211+
await websocket_proxy(ws, id="sb1", path="ws", rock_target_port=80)
203212

204-
# Service should NOT be called for invalid port
205213
svc.websocket_proxy.assert_not_called()
214+
ws.close.assert_called_once()
215+
assert ws.close.call_args.kwargs.get("code") == 1008 or ws.close.call_args.args[0] == 1008
206216

207217
async def test_websocket_proxy_port_from_header(self, app):
208218
"""When X-ROCK-Target-Port header is given, service.websocket_proxy should receive the port."""
209-
a, svc = app
210-
client = TestClientWS(a)
211-
try:
212-
with client.websocket_connect("/sandboxes/sb1/proxy/ws", headers={"X-ROCK-Target-Port": "8888"}):
213-
pass
214-
except Exception:
215-
pass
219+
_a, svc = app
220+
ws = _make_mock_websocket(headers={"X-ROCK-Target-Port": "8888"})
221+
await websocket_proxy(ws, id="sb1", path="ws", rock_target_port=None)
216222

217223
svc.websocket_proxy.assert_called_once()
218224
call = svc.websocket_proxy.call_args
@@ -221,17 +227,12 @@ async def test_websocket_proxy_port_from_header(self, app):
221227

222228
async def test_websocket_proxy_port_conflict(self, app):
223229
"""When both header and query param are given, should close with error."""
224-
a, svc = app
225-
client = TestClientWS(a)
226-
try:
227-
with client.websocket_connect(
228-
"/sandboxes/sb1/proxy/ws?rock_target_port=8000", headers={"X-ROCK-Target-Port": "9000"}
229-
):
230-
pass
231-
except Exception:
232-
pass
230+
_a, svc = app
231+
ws = _make_mock_websocket(headers={"X-ROCK-Target-Port": "9000"})
232+
await websocket_proxy(ws, id="sb1", path="ws", rock_target_port=8000)
233233

234234
svc.websocket_proxy.assert_not_called()
235+
ws.close.assert_called_once()
235236

236237

237238
# ─────────────────────────────────────────────────────────────────────────────
@@ -893,13 +894,9 @@ class TestPathBasedPortWsRouting:
893894

894895
async def test_ws_port_in_path_is_extracted(self, app):
895896
"""WS /proxy/port/8006/ws should forward path='ws' with port=8006."""
896-
a, svc = app
897-
client = TestClientWS(a)
898-
try:
899-
with client.websocket_connect("/sandboxes/sb1/proxy/port/8006/ws"):
900-
pass
901-
except Exception:
902-
pass
897+
_a, svc = app
898+
ws = _make_mock_websocket()
899+
await websocket_proxy(ws, id="sb1", path="port/8006/ws", rock_target_port=None)
903900

904901
svc.websocket_proxy.assert_called_once()
905902
call = svc.websocket_proxy.call_args
@@ -910,15 +907,12 @@ async def test_ws_port_in_path_is_extracted(self, app):
910907

911908
async def test_ws_port_in_path_with_query_param_conflict(self, app):
912909
"""WS /proxy/port/8006/ws?rock_target_port=9000 should close with error (conflict)."""
913-
a, svc = app
914-
client = TestClientWS(a)
915-
try:
916-
with client.websocket_connect("/sandboxes/sb1/proxy/port/8006/ws?rock_target_port=9000"):
917-
pass
918-
except Exception:
919-
pass
910+
_a, svc = app
911+
ws = _make_mock_websocket()
912+
await websocket_proxy(ws, id="sb1", path="port/8006/ws", rock_target_port=9000)
920913

921914
svc.websocket_proxy.assert_not_called()
915+
ws.close.assert_called_once()
922916

923917

924918
# ─────────────────────────────────────────────────────────────────────────────
@@ -999,13 +993,9 @@ class TestVncWebSocketProxy:
999993

1000994
async def test_vnc_ws_route_forwards_to_port_8006(self, app):
1001995
"""WS /vnc/ws should forward to port 8006."""
1002-
a, svc = app
1003-
client = TestClientWS(a)
1004-
try:
1005-
with client.websocket_connect("/sandboxes/sb1/vnc/ws"):
1006-
pass
1007-
except Exception:
1008-
pass
996+
_a, svc = app
997+
ws = _make_mock_websocket()
998+
await vnc_websocket_proxy(ws, sandbox_id="sb1", path="ws")
1009999

10101000
svc.websocket_proxy.assert_called_once()
10111001
call = svc.websocket_proxy.call_args
@@ -1014,13 +1004,9 @@ async def test_vnc_ws_route_forwards_to_port_8006(self, app):
10141004

10151005
async def test_vnc_ws_route_preserves_path(self, app):
10161006
"""WS /vnc/websockify should forward path='websockify'."""
1017-
a, svc = app
1018-
client = TestClientWS(a)
1019-
try:
1020-
with client.websocket_connect("/sandboxes/sb1/vnc/websockify"):
1021-
pass
1022-
except Exception:
1023-
pass
1007+
_a, svc = app
1008+
ws = _make_mock_websocket()
1009+
await vnc_websocket_proxy(ws, sandbox_id="sb1", path="websockify")
10241010

10251011
svc.websocket_proxy.assert_called_once()
10261012
call = svc.websocket_proxy.call_args
@@ -1029,34 +1015,11 @@ async def test_vnc_ws_route_preserves_path(self, app):
10291015

10301016
async def test_vnc_ws_route_ignores_query_param_port(self, app):
10311017
"""VNC WS proxy should ignore rock_target_port and always use 8006."""
1032-
a, svc = app
1033-
client = TestClientWS(a)
1034-
try:
1035-
with client.websocket_connect("/sandboxes/sb1/vnc/ws?rock_target_port=9000"):
1036-
pass
1037-
except Exception:
1038-
pass
1018+
_a, svc = app
1019+
ws = _make_mock_websocket()
1020+
await vnc_websocket_proxy(ws, sandbox_id="sb1", path="ws")
10391021

10401022
svc.websocket_proxy.assert_called_once()
10411023
call = svc.websocket_proxy.call_args
10421024
port = call.kwargs.get("port") or (call.args[3] if len(call.args) > 3 else None)
10431025
assert port == 8006
1044-
1045-
1046-
# ─────────────────────────────────────────────────────────────────────────────
1047-
# Helper — sync WebSocket test client wrapper
1048-
# ─────────────────────────────────────────────────────────────────────────────
1049-
1050-
1051-
class TestClientWS:
1052-
"""Thin wrapper around FastAPI TestClient for WebSocket connections."""
1053-
1054-
def __init__(self, app):
1055-
from fastapi.testclient import TestClient
1056-
1057-
self._client = TestClient(app, raise_server_exceptions=False)
1058-
1059-
def websocket_connect(self, path, headers=None):
1060-
if headers is None:
1061-
headers = {}
1062-
return self._client.websocket_connect(path, headers=headers)

0 commit comments

Comments
 (0)