Skip to content

Commit 1e8a062

Browse files
Debug Agentclaude
andcommitted
fix(security): first-message WebSocket auth to prevent token leakage
Clients can now send {"type": "auth", "session_api_key": "..."} as the first WebSocket frame instead of passing the token as a query parameter. This keeps sk-oh-* tokens out of reverse-proxy / load-balancer access logs (Traefik, Datadog, etc.). Query param and header auth are preserved as deprecated fallbacks for backwards compatibility. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 53b8038 commit 1e8a062

File tree

2 files changed

+320
-9
lines changed

2 files changed

+320
-9
lines changed

openhands-agent-server/openhands/agent_server/sockets.py

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,18 @@
22
WebSocket endpoints for OpenHands SDK.
33
44
These endpoints are separate from the main API routes to handle WebSocket-specific
5-
authentication. Browsers cannot send custom HTTP headers directly with WebSocket
6-
connections, so we support the `session_api_key` query param. For non-browser
7-
clients (e.g. Python/Node), we also support authenticating via headers.
5+
authentication. Three auth methods are supported (highest to lowest precedence):
6+
7+
1. **First-message auth** (recommended): The client sends
8+
``{"type": "auth", "session_api_key": "..."}`` as the very first WebSocket
9+
frame after the connection opens. This keeps tokens out of URLs and
10+
therefore out of reverse-proxy / load-balancer access logs.
11+
2. Query parameter ``session_api_key`` — deprecated, kept for backwards compat.
12+
3. ``X-Session-API-Key`` header — for non-browser clients.
813
"""
914

15+
import asyncio
16+
import json
1017
import logging
1118
from dataclasses import dataclass
1219
from datetime import datetime
@@ -78,19 +85,68 @@ def _resolve_websocket_session_api_key(
7885
return None
7986

8087

88+
_FIRST_MESSAGE_AUTH_TIMEOUT_SECONDS = 10
89+
90+
8191
async def _accept_authenticated_websocket(
8292
websocket: WebSocket,
8393
session_api_key: str | None,
8494
) -> bool:
85-
"""Authenticate and accept the socket, or close with an auth error."""
95+
"""Authenticate and accept the socket, or close with an auth error.
96+
97+
Authentication is attempted in the following order:
98+
99+
1. Query parameter / header (legacy, deprecated).
100+
2. First-message auth — the client sends
101+
``{"type": "auth", "session_api_key": "..."}`` as the first frame.
102+
103+
The WebSocket is always *accepted* before first-message auth is attempted
104+
because raw WebSocket requires ``accept()`` before any frames can be read.
105+
"""
86106
config = _get_config(websocket)
87107
resolved_key = _resolve_websocket_session_api_key(websocket, session_api_key)
88-
if config.session_api_keys and resolved_key not in config.session_api_keys:
89-
logger.warning("WebSocket authentication failed: invalid or missing API key")
108+
109+
# No auth configured — accept unconditionally.
110+
if not config.session_api_keys:
111+
await websocket.accept()
112+
return True
113+
114+
# Legacy path: key supplied via query param or header.
115+
if resolved_key is not None:
116+
if resolved_key in config.session_api_keys:
117+
logger.warning(
118+
"session_api_key passed via query param or header is deprecated. "
119+
"Use first-message auth instead."
120+
)
121+
await websocket.accept()
122+
return True
123+
logger.warning("WebSocket authentication failed: invalid API key")
90124
await websocket.close(code=4001, reason="Authentication failed")
91125
return False
126+
127+
# First-message auth: accept the connection, then read the first frame.
92128
await websocket.accept()
93-
return True
129+
try:
130+
raw = await asyncio.wait_for(
131+
websocket.receive_text(),
132+
timeout=_FIRST_MESSAGE_AUTH_TIMEOUT_SECONDS,
133+
)
134+
data = json.loads(raw)
135+
except (asyncio.TimeoutError, json.JSONDecodeError, WebSocketDisconnect):
136+
logger.warning("WebSocket first-message auth failed: bad or missing payload")
137+
await _safe_close_websocket(websocket, code=4001, reason="Authentication failed")
138+
return False
139+
140+
if (
141+
isinstance(data, dict)
142+
and data.get("type") == "auth"
143+
and data.get("session_api_key") in config.session_api_keys
144+
):
145+
return True
146+
147+
logger.warning("WebSocket first-message auth failed: invalid key or payload")
148+
await _safe_close_websocket(websocket, code=4001, reason="Authentication failed")
149+
return False
94150

95151

96152
@sockets_router.websocket("/events/{conversation_id}")
@@ -329,9 +385,13 @@ async def _send_event(event: Event, websocket: WebSocket):
329385
logger.exception("error_sending_event: %r", event, stack_info=True)
330386

331387

332-
async def _safe_close_websocket(websocket: WebSocket):
388+
async def _safe_close_websocket(
389+
websocket: WebSocket,
390+
code: int = 1000,
391+
reason: str = "Connection closed",
392+
):
333393
try:
334-
await websocket.close(code=1000, reason="Connection closed")
394+
await websocket.close(code=code, reason=reason)
335395
except Exception:
336396
# WebSocket may already be closed or in inconsistent state
337397
logger.debug("WebSocket close failed (may already be closed)")
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
"""Tests for first-message WebSocket authentication in sockets.py."""
2+
3+
import asyncio
4+
import json
5+
from unittest.mock import AsyncMock, MagicMock, patch
6+
from uuid import uuid4
7+
8+
import pytest
9+
from fastapi import WebSocketDisconnect
10+
11+
from openhands.agent_server.sockets import _accept_authenticated_websocket
12+
13+
14+
def _make_mock_websocket(*, headers=None):
15+
"""Build a mock WebSocket with configurable query params and headers."""
16+
ws = MagicMock()
17+
ws.accept = AsyncMock()
18+
ws.receive_text = AsyncMock()
19+
ws.receive_json = AsyncMock()
20+
ws.send_json = AsyncMock()
21+
ws.close = AsyncMock()
22+
ws.headers = headers or {}
23+
return ws
24+
25+
26+
# -- No auth configured (empty session_api_keys) --
27+
28+
29+
@pytest.mark.asyncio
30+
async def test_no_auth_configured_accepts_immediately():
31+
ws = _make_mock_websocket()
32+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
33+
mock_config.return_value.session_api_keys = []
34+
result = await _accept_authenticated_websocket(ws, session_api_key=None)
35+
36+
assert result is True
37+
ws.accept.assert_called_once()
38+
ws.receive_text.assert_not_called()
39+
40+
41+
# -- Legacy query param auth (deprecated) --
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_legacy_query_param_valid_key():
46+
ws = _make_mock_websocket()
47+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
48+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
49+
result = await _accept_authenticated_websocket(ws, session_api_key="sk-oh-valid")
50+
51+
assert result is True
52+
ws.accept.assert_called_once()
53+
ws.receive_text.assert_not_called()
54+
55+
56+
@pytest.mark.asyncio
57+
async def test_legacy_query_param_invalid_key():
58+
ws = _make_mock_websocket()
59+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
60+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
61+
result = await _accept_authenticated_websocket(ws, session_api_key="sk-oh-wrong")
62+
63+
assert result is False
64+
ws.close.assert_called_once_with(code=4001, reason="Authentication failed")
65+
ws.accept.assert_not_called()
66+
67+
68+
# -- Legacy header auth (deprecated) --
69+
70+
71+
@pytest.mark.asyncio
72+
async def test_legacy_header_valid_key():
73+
ws = _make_mock_websocket(headers={"x-session-api-key": "sk-oh-valid"})
74+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
75+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
76+
result = await _accept_authenticated_websocket(ws, session_api_key=None)
77+
78+
assert result is True
79+
ws.accept.assert_called_once()
80+
81+
82+
@pytest.mark.asyncio
83+
async def test_legacy_header_invalid_key():
84+
ws = _make_mock_websocket(headers={"x-session-api-key": "sk-oh-wrong"})
85+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
86+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
87+
result = await _accept_authenticated_websocket(ws, session_api_key=None)
88+
89+
assert result is False
90+
ws.close.assert_called_once_with(code=4001, reason="Authentication failed")
91+
92+
93+
# -- First-message auth --
94+
95+
96+
@pytest.mark.asyncio
97+
async def test_first_message_auth_valid_key():
98+
ws = _make_mock_websocket()
99+
ws.receive_text.return_value = json.dumps(
100+
{"type": "auth", "session_api_key": "sk-oh-valid"}
101+
)
102+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
103+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
104+
result = await _accept_authenticated_websocket(ws, session_api_key=None)
105+
106+
assert result is True
107+
ws.accept.assert_called_once()
108+
ws.receive_text.assert_called_once()
109+
110+
111+
@pytest.mark.asyncio
112+
async def test_first_message_auth_invalid_key():
113+
ws = _make_mock_websocket()
114+
ws.receive_text.return_value = json.dumps(
115+
{"type": "auth", "session_api_key": "sk-oh-wrong"}
116+
)
117+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
118+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
119+
result = await _accept_authenticated_websocket(ws, session_api_key=None)
120+
121+
assert result is False
122+
ws.accept.assert_called_once() # accepted before reading first message
123+
ws.close.assert_called_once_with(code=4001, reason="Authentication failed")
124+
125+
126+
@pytest.mark.asyncio
127+
async def test_first_message_auth_wrong_type_field():
128+
ws = _make_mock_websocket()
129+
ws.receive_text.return_value = json.dumps(
130+
{"type": "message", "session_api_key": "sk-oh-valid"}
131+
)
132+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
133+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
134+
result = await _accept_authenticated_websocket(ws, session_api_key=None)
135+
136+
assert result is False
137+
138+
139+
@pytest.mark.asyncio
140+
async def test_first_message_auth_missing_key_field():
141+
ws = _make_mock_websocket()
142+
ws.receive_text.return_value = json.dumps({"type": "auth"})
143+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
144+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
145+
result = await _accept_authenticated_websocket(ws, session_api_key=None)
146+
147+
assert result is False
148+
149+
150+
@pytest.mark.asyncio
151+
async def test_first_message_auth_malformed_json():
152+
ws = _make_mock_websocket()
153+
ws.receive_text.return_value = "not json at all"
154+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
155+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
156+
result = await _accept_authenticated_websocket(ws, session_api_key=None)
157+
158+
assert result is False
159+
ws.close.assert_called_once_with(code=4001, reason="Authentication failed")
160+
161+
162+
@pytest.mark.asyncio
163+
async def test_first_message_auth_client_disconnects():
164+
ws = _make_mock_websocket()
165+
ws.receive_text.side_effect = WebSocketDisconnect()
166+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
167+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
168+
result = await _accept_authenticated_websocket(ws, session_api_key=None)
169+
170+
assert result is False
171+
172+
173+
@pytest.mark.asyncio
174+
async def test_first_message_auth_timeout():
175+
ws = _make_mock_websocket()
176+
177+
async def slow_receive():
178+
await asyncio.sleep(60)
179+
180+
ws.receive_text.side_effect = slow_receive
181+
182+
with (
183+
patch("openhands.agent_server.sockets.get_default_config") as mock_config,
184+
patch(
185+
"openhands.agent_server.sockets._FIRST_MESSAGE_AUTH_TIMEOUT_SECONDS", 0.05
186+
),
187+
):
188+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
189+
result = await _accept_authenticated_websocket(ws, session_api_key=None)
190+
191+
assert result is False
192+
ws.close.assert_called_once_with(code=4001, reason="Authentication failed")
193+
194+
195+
# -- End-to-end: first-message auth through events_socket --
196+
197+
198+
@pytest.mark.asyncio
199+
async def test_events_socket_first_message_auth_e2e():
200+
"""First-message auth works end-to-end through the events_socket endpoint."""
201+
from openhands.agent_server.event_service import EventService
202+
from openhands.agent_server.sockets import events_socket
203+
204+
ws = _make_mock_websocket()
205+
# First call: auth message (read by _accept_authenticated_websocket via receive_text)
206+
# Then receive_json calls: disconnect
207+
ws.receive_text.return_value = json.dumps(
208+
{"type": "auth", "session_api_key": "sk-oh-valid"}
209+
)
210+
ws.receive_json.side_effect = WebSocketDisconnect()
211+
212+
mock_event_service = MagicMock(spec=EventService)
213+
mock_event_service.subscribe_to_events = AsyncMock(return_value=uuid4())
214+
mock_event_service.unsubscribe_from_events = AsyncMock(return_value=True)
215+
216+
with (
217+
patch(
218+
"openhands.agent_server.sockets.conversation_service"
219+
) as mock_conv_service,
220+
patch("openhands.agent_server.sockets.get_default_config") as mock_config,
221+
):
222+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
223+
mock_conv_service.get_event_service = AsyncMock(
224+
return_value=mock_event_service
225+
)
226+
227+
await events_socket(uuid4(), ws, session_api_key=None)
228+
229+
ws.accept.assert_called_once()
230+
mock_event_service.subscribe_to_events.assert_called_once()
231+
mock_event_service.unsubscribe_from_events.assert_called_once()
232+
233+
234+
@pytest.mark.asyncio
235+
async def test_events_socket_first_message_auth_rejected():
236+
"""events_socket returns early when first-message auth fails."""
237+
from openhands.agent_server.sockets import events_socket
238+
239+
ws = _make_mock_websocket()
240+
ws.receive_text.return_value = json.dumps(
241+
{"type": "auth", "session_api_key": "sk-oh-wrong"}
242+
)
243+
244+
with patch("openhands.agent_server.sockets.get_default_config") as mock_config:
245+
mock_config.return_value.session_api_keys = ["sk-oh-valid"]
246+
247+
await events_socket(uuid4(), ws, session_api_key=None)
248+
249+
ws.accept.assert_called_once()
250+
# Should not proceed to subscribe
251+
ws.receive_json.assert_not_called()

0 commit comments

Comments
 (0)