Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions quanttradeai/streaming/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from .monitoring.metrics import Metrics
from .logging import logger
from .providers import ProviderHealthMonitor

AdapterMap = {
"alpaca": AlpacaAdapter,
Expand All @@ -58,6 +59,7 @@ class StreamingGateway:
buffer: StreamBuffer = field(init=False)
metrics: Metrics = field(init=False)
health_monitor: StreamingHealthMonitor = field(init=False)
provider_monitor: ProviderHealthMonitor = field(init=False)
_subscriptions: List[Tuple[str, List[str]]] = field(default_factory=list)
_callbacks: Dict[str, List[Callback]] = field(default_factory=dict)
_config_symbols: List[str] = field(default_factory=list)
Expand Down Expand Up @@ -98,6 +100,9 @@ def __post_init__(self) -> None:
queue_size_fn=self.buffer.queue.qsize,
queue_name="stream",
)
self.provider_monitor = ProviderHealthMonitor(
streaming_monitor=self.health_monitor
)
api_cfg = health_cfg.get("api", {})
self._api_enabled = api_cfg.get("enabled", False)
self._api_host = api_cfg.get("host", "0.0.0.0")
Expand All @@ -117,6 +122,17 @@ def __post_init__(self) -> None:
rate_limit_cfg=provider_cfg.get("rate_limit"),
auth_method=provider_cfg.get("auth_method", "api_key"),
)

async def _failover(adapter=adapter) -> None:
await self.websocket_manager._connect_with_retry(
adapter,
monitor=self.provider_monitor,
)

self.provider_monitor.register_provider(
adapter.name,
failover_handler=_failover,
)
# Config-driven subscriptions (optional)
subs = provider_cfg.get("subscriptions", []) or []
prov_symbols = provider_cfg.get("symbols", self._config_symbols) or []
Expand Down Expand Up @@ -159,7 +175,7 @@ async def _dispatch(self, provider: str, message: Dict) -> None:
logger.error("callback_error", error=str(exc))

async def _start(self) -> None:
await self.websocket_manager.connect_all()
await self.websocket_manager.connect_all(monitor=self.provider_monitor)
for adapter in self.websocket_manager.adapters:
self.health_monitor.register_connection(adapter.name)
asyncio.create_task(self.health_monitor.monitor_connection_health())
Expand Down Expand Up @@ -188,7 +204,18 @@ async def _start(self) -> None:
)
for channel, symbols in self._subscriptions:
for adapter in self.websocket_manager.adapters:
await adapter.subscribe(channel, symbols)

async def _subscribe(
adapter=adapter,
channel=channel,
symbols=symbols,
) -> None:
await adapter.subscribe(channel, symbols)

await self.provider_monitor.execute_with_health(
adapter.name,
_subscribe,
)
await self.websocket_manager.run(self._dispatch)

def start_streaming(self) -> None:
Expand Down
23 changes: 19 additions & 4 deletions quanttradeai/streaming/websocket_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .rate_limiter import AdaptiveRateLimiter
from .connection_pool import ConnectionPool
from .logging import logger
from .providers.health import ProviderHealthMonitor

Callback = Callable[[str, dict], Awaitable[None]]

Expand Down Expand Up @@ -49,11 +50,23 @@ def add_adapter(
adapter.auth_manager = AuthManager(adapter.name)
self.adapters.append(adapter)

async def _connect_with_retry(self, adapter: DataProviderAdapter) -> None:
async def _connect_with_retry(
self,
adapter: DataProviderAdapter,
*,
monitor: Optional[ProviderHealthMonitor] = None,
) -> None:
delay = 1.0
for attempt in range(self.reconnect_attempts):
try:
await adapter.circuit_breaker.call_async(adapter.connect)
if monitor is not None:

async def _operation() -> None:
await adapter.circuit_breaker.call_async(adapter.connect)

await monitor.execute_with_health(adapter.name, _operation)
else:
await adapter.circuit_breaker.call_async(adapter.connect)
return
except CircuitBreakerError as exc:
logger.error("circuit_open", provider=adapter.name, error=str(exc))
Expand All @@ -70,11 +83,13 @@ async def _connect_with_retry(self, adapter: DataProviderAdapter) -> None:
await asyncio.sleep(delay)
delay *= 2

async def connect_all(self) -> None:
async def connect_all(
self, *, monitor: Optional[ProviderHealthMonitor] = None
) -> None:
"""Connect all registered adapters."""

for adapter in self.adapters:
await self._connect_with_retry(adapter)
await self._connect_with_retry(adapter, monitor=monitor)
self.connection_pool._active.update(
ad.connection for ad in self.adapters if ad.connection
)
Expand Down
193 changes: 176 additions & 17 deletions tests/streaming/test_gateway.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,75 @@
import asyncio
import json
import tempfile
import asyncio
import json
import tempfile
from unittest.mock import patch

import yaml

from quanttradeai.streaming import StreamingGateway
import asyncio
import json
import tempfile
import time
from typing import Awaitable, Callable, Dict, List, Optional
from unittest.mock import AsyncMock, patch

import yaml

from quanttradeai.streaming.monitoring import StreamingHealthMonitor

import pytest
from quanttradeai.streaming import StreamingGateway


class StubProviderMonitor:
def __init__(
self,
*,
streaming_monitor: Optional[StreamingHealthMonitor] = None,
**_: Dict,
) -> None:
self.streaming_monitor = streaming_monitor or StreamingHealthMonitor()
self.recovery_manager = self.streaming_monitor.recovery_manager
self.registered: List[str] = []
self.failover_handlers: Dict[str, Callable[[], Awaitable[None]]] = {}
self.execute_calls: List[str] = []
self.record_success_calls: List[float] = []
self.record_failure_calls: List[str] = []
self.status_providers: Dict[str, Callable[[], object]] = {}

def register_provider(
self,
provider_name: str,
*,
failover_handler: Optional[Callable[[], Awaitable[None]]] = None,
status_provider: Optional[Callable[[], object]] = None,
) -> None:
self.streaming_monitor.register_connection(provider_name)
self.registered.append(provider_name)
if failover_handler is not None:
self.failover_handlers[provider_name] = failover_handler
if status_provider is not None:
self.status_providers[provider_name] = status_provider

async def execute_with_health(
self,
provider_name: str,
operation: Callable[[], Awaitable[object]],
*,
fallback: Optional[Callable[[], Awaitable[object]]] = None,
) -> object:
self.execute_calls.append(provider_name)
start = time.perf_counter()
try:
result = await operation()
except Exception as exc:
await self.record_failure(provider_name, exc)
if fallback is not None:
return await fallback()
raise
latency_ms = (time.perf_counter() - start) * 1000.0
await self.record_success(provider_name, latency_ms)
return result

async def record_success(
self, provider_name: str, latency_ms: float, *, bytes_received: int = 0
) -> None:
self.record_success_calls.append(latency_ms)

async def record_failure(self, provider_name: str, error: Exception) -> None:
self.record_failure_calls.append(provider_name)


class FakeConnection:
Expand All @@ -31,11 +92,11 @@ async def close(self):
pass


def test_gateway_streaming():
msg = json.dumps({"type": "trades", "symbol": "TEST", "price": 1})
async def connect(url, *_, **__):
return FakeConnection([msg])
def test_gateway_streaming():
msg = json.dumps({"type": "trades", "symbol": "TEST", "price": 1})

async def connect(url, *_, **__):
return FakeConnection([msg])

async def run_test():
with patch("websockets.connect", new=connect):
Expand Down Expand Up @@ -68,4 +129,102 @@ async def run_test():
except Exception:
pass

asyncio.run(run_test())
asyncio.run(run_test())


@patch("quanttradeai.streaming.gateway.ProviderHealthMonitor", new=StubProviderMonitor)
def test_gateway_registers_providers_and_failover(tmp_path):
cfg = {
"streaming": {
"providers": [
{
"name": "alpaca",
"websocket_url": "ws://test",
"auth_method": "none",
}
]
}
}
config_file = tmp_path / "streaming.yaml"
config_file.write_text(yaml.safe_dump(cfg))
gateway = StreamingGateway(str(config_file))
monitor = gateway.provider_monitor
assert monitor.registered == ["alpaca"]
adapter = gateway.websocket_manager.adapters[0]
gateway.websocket_manager._connect_with_retry = AsyncMock()
failover = monitor.failover_handlers[adapter.name]
asyncio.run(failover())
gateway.websocket_manager._connect_with_retry.assert_awaited_once()
_, kwargs = gateway.websocket_manager._connect_with_retry.await_args
assert kwargs["monitor"] is monitor


@patch("quanttradeai.streaming.gateway.ProviderHealthMonitor", new=StubProviderMonitor)
def test_gateway_start_uses_provider_monitor(tmp_path):
cfg = {
"streaming": {
"providers": [
{
"name": "alpaca",
"websocket_url": "ws://test",
"auth_method": "none",
}
]
}
}
config_file = tmp_path / "streaming.yaml"
config_file.write_text(yaml.safe_dump(cfg))
gateway = StreamingGateway(str(config_file))
adapter = gateway.websocket_manager.adapters[0]
gateway.subscribe_to_trades(["TEST"], callback=lambda _: None)
adapter.subscribe = AsyncMock(return_value=None)
gateway.websocket_manager.connect_all = AsyncMock()
gateway.websocket_manager.run = AsyncMock()
gateway.health_monitor.monitor_connection_health = AsyncMock()

async def run_start():
await gateway._start()

asyncio.run(run_start())

gateway.websocket_manager.connect_all.assert_awaited_once()
_, kwargs = gateway.websocket_manager.connect_all.await_args
assert kwargs["monitor"] is gateway.provider_monitor
adapter.subscribe.assert_awaited_once()
assert gateway.provider_monitor.execute_calls.count("alpaca") == len(
gateway._subscriptions
)


@patch("quanttradeai.streaming.gateway.ProviderHealthMonitor", new=StubProviderMonitor)
def test_websocket_manager_reports_failures(tmp_path):
cfg = {
"streaming": {
"providers": [
{
"name": "alpaca",
"websocket_url": "ws://test",
"auth_method": "none",
}
]
}
}
config_file = tmp_path / "streaming.yaml"
config_file.write_text(yaml.safe_dump(cfg))
gateway = StreamingGateway(str(config_file))
monitor = gateway.provider_monitor
manager = gateway.websocket_manager
adapter = manager.adapters[0]
manager.reconnect_attempts = 1

async def failing_connect():
raise RuntimeError("boom")

adapter.connect = AsyncMock(side_effect=failing_connect)

async def run_connect():
await manager._connect_with_retry(adapter, monitor=monitor)

with pytest.raises(RuntimeError):
asyncio.run(run_connect())
assert monitor.record_failure_calls