Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class Settings(BaseSettings):
cors_origins: str = ""
base_url: str = ""

# Gateway WebSocket Origin header (sent on all gateway WS connections).
# Must match gateway's controlUi.allowedOrigins. Defaults to base_url if empty.
gateway_origin: str = ""

# Security response headers (set to blank to disable a specific header)
security_header_x_content_type_options: str = "nosniff"
security_header_x_frame_options: str = "DENY"
Expand Down
52 changes: 26 additions & 26 deletions backend/app/services/openclaw/gateway_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import websockets
from websockets.exceptions import WebSocketException

from app.core.config import settings
from app.core.logging import TRACE_LEVEL, get_logger
from app.services.openclaw.device_identity import (
build_device_auth_payload,
Expand All @@ -41,6 +42,23 @@
CONTROL_UI_CLIENT_MODE = "ui"
GatewayConnectMode = Literal["device", "control_ui"]


def _get_gateway_origin() -> str:
"""Get the origin header for gateway WebSocket connections.

Falls back to base_url if gateway_origin is not configured.
"""
origin = settings.gateway_origin.strip()
if origin:
return origin
# Fallback to base_url for backward compat / convenience.
base = settings.base_url.strip()
if base:
return base
# Last resort: localhost for local dev.
return "http://localhost:3000"


# NOTE: These are the base gateway methods from the OpenClaw gateway repo.
# The gateway can expose additional methods at runtime via channel plugins.
GATEWAY_METHODS = [
Expand Down Expand Up @@ -212,24 +230,6 @@ def _create_ssl_context(config: GatewayConfig) -> ssl.SSLContext | None:
return ssl_context


def _build_control_ui_origin(gateway_url: str) -> str | None:
parsed = urlparse(gateway_url)
if not parsed.hostname:
return None
if parsed.scheme in {"ws", "http"}:
origin_scheme = "http"
elif parsed.scheme in {"wss", "https"}:
origin_scheme = "https"
else:
return None
host = parsed.hostname
if ":" in host and not host.startswith("["):
host = f"[{host}]"
if parsed.port is not None:
host = f"{host}:{parsed.port}"
return f"{origin_scheme}://{host}"


def _resolve_connect_mode(config: GatewayConfig) -> GatewayConnectMode:
return "control_ui" if config.disable_device_pairing else "device"

Expand Down Expand Up @@ -401,11 +401,11 @@ async def _openclaw_call_once(
config: GatewayConfig,
gateway_url: str,
) -> object:
origin = _build_control_ui_origin(gateway_url) if config.disable_device_pairing else None
# Always send the MC origin header so the gateway accepts the connection
# regardless of the gateway URL (avoids "origin not allowed" rejections).
ssl_context = _create_ssl_context(config)
connect_kwargs: dict[str, Any] = {"ping_interval": None}
if origin is not None:
connect_kwargs["origin"] = origin
origin = _get_gateway_origin()
connect_kwargs: dict[str, Any] = {"ping_interval": None, "additional_headers": {"Origin": origin}}
if ssl_context is not None:
Comment on lines +404 to 409
connect_kwargs["ssl"] = ssl_context
async with websockets.connect(gateway_url, **connect_kwargs) as ws:
Expand All @@ -419,11 +419,11 @@ async def _openclaw_connect_metadata_once(
config: GatewayConfig,
gateway_url: str,
) -> object:
origin = _build_control_ui_origin(gateway_url) if config.disable_device_pairing else None
# Always send the MC origin header so the gateway accepts the connection
# regardless of the gateway URL (avoids "origin not allowed" rejections).
ssl_context = _create_ssl_context(config)
connect_kwargs: dict[str, Any] = {"ping_interval": None}
if origin is not None:
connect_kwargs["origin"] = origin
origin = _get_gateway_origin()
connect_kwargs: dict[str, Any] = {"ping_interval": None, "additional_headers": {"Origin": origin}}
if ssl_context is not None:
connect_kwargs["ssl"] = ssl_context
async with websockets.connect(gateway_url, **connect_kwargs) as ws:
Expand Down
65 changes: 50 additions & 15 deletions backend/tests/test_gateway_rpc_connect_scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
GatewayConfig,
OpenClawGatewayError,
_build_connect_params,
_build_control_ui_origin,
openclaw_call,
)

Expand Down Expand Up @@ -128,20 +127,6 @@ def _fake_build_device_connect_payload(
assert captured["connect_nonce"] == "nonce-xyz"


@pytest.mark.parametrize(
("gateway_url", "expected_origin"),
[
("ws://gateway.example/ws", "http://gateway.example"),
("wss://gateway.example/ws", "https://gateway.example"),
("ws://gateway.example:8080/ws", "http://gateway.example:8080"),
("wss://gateway.example:8443/ws", "https://gateway.example:8443"),
("ws://[::1]:8000/ws", "http://[::1]:8000"),
],
)
def test_build_control_ui_origin(gateway_url: str, expected_origin: str) -> None:
assert _build_control_ui_origin(gateway_url) == expected_origin


@pytest.mark.asyncio
async def test_openclaw_call_uses_single_connect_attempt(
monkeypatch: pytest.MonkeyPatch,
Expand Down Expand Up @@ -228,6 +213,7 @@ async def _fake_send_request(_ws: object, _method: str, _params: object) -> obje
monkeypatch.setattr(gateway_rpc, "_recv_first_message_or_none", _fake_recv_first)
monkeypatch.setattr(gateway_rpc, "_ensure_connected", _fake_ensure_connected)
monkeypatch.setattr(gateway_rpc, "_send_request", _fake_send_request)
monkeypatch.setattr(gateway_rpc.settings, "gateway_origin", "https://mc.example")

payload = await gateway_rpc._openclaw_call_once(
"status",
Expand All @@ -241,6 +227,9 @@ async def _fake_send_request(_ws: object, _method: str, _params: object) -> obje
kwargs = captured["kwargs"]
assert isinstance(kwargs, dict)
assert "ssl" not in kwargs
headers = kwargs.get("additional_headers")
assert isinstance(headers, dict)
assert headers.get("Origin") == "https://mc.example"


@pytest.mark.asyncio
Expand Down Expand Up @@ -269,6 +258,7 @@ async def _fake_send_request(_ws: object, _method: str, _params: object) -> obje
monkeypatch.setattr(gateway_rpc, "_recv_first_message_or_none", _fake_recv_first)
monkeypatch.setattr(gateway_rpc, "_ensure_connected", _fake_ensure_connected)
monkeypatch.setattr(gateway_rpc, "_send_request", _fake_send_request)
monkeypatch.setattr(gateway_rpc.settings, "gateway_origin", "https://mc.example")

payload = await gateway_rpc._openclaw_call_once(
"status",
Expand All @@ -282,3 +272,48 @@ async def _fake_send_request(_ws: object, _method: str, _params: object) -> obje
kwargs = captured["kwargs"]
assert isinstance(kwargs, dict)
assert kwargs.get("ssl") is not None
headers = kwargs.get("additional_headers")
assert isinstance(headers, dict)
assert headers.get("Origin") == "https://mc.example"


@pytest.mark.asyncio
async def test_openclaw_connect_metadata_once_passes_origin_header(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured: dict[str, object] = {}

def _fake_connect(url: str, **kwargs: object) -> _FakeConnectContext:
captured["url"] = url
captured["kwargs"] = kwargs
return _FakeConnectContext()

async def _fake_recv_first(_ws: object) -> None:
return None

async def _fake_ensure_connected(
_ws: object, _first_message: object, _config: GatewayConfig
) -> None:
return None

async def _fake_send_request(_ws: object, _method: str, _params: object) -> object:
return {"ok": True}

monkeypatch.setattr(gateway_rpc.websockets, "connect", _fake_connect)
monkeypatch.setattr(gateway_rpc, "_recv_first_message_or_none", _fake_recv_first)
monkeypatch.setattr(gateway_rpc, "_ensure_connected", _fake_ensure_connected)
monkeypatch.setattr(gateway_rpc, "_send_request", _fake_send_request)
monkeypatch.setattr(gateway_rpc.settings, "gateway_origin", "https://mc.example")

payload = await gateway_rpc._openclaw_connect_metadata_once(
config=GatewayConfig(url="wss://gateway.example/ws", allow_insecure_tls=False),
gateway_url="wss://gateway.example/ws",
)

assert payload == {"ok": True}
assert captured["url"] == "wss://gateway.example/ws"
kwargs = captured["kwargs"]
assert isinstance(kwargs, dict)
headers = kwargs.get("additional_headers")
assert isinstance(headers, dict)
assert headers.get("Origin") == "https://mc.example"
Loading