Skip to content

Commit e0da1da

Browse files
authored
Merge pull request #50 from AKKI0511/analyze-quanttradeai-for-new-feature-and-improvements
feat(streaming): integrate provider health monitoring
2 parents a148789 + bd2ff34 commit e0da1da

File tree

4 files changed

+231
-24
lines changed

4 files changed

+231
-24
lines changed

quanttradeai/streaming/gateway.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040
from .monitoring.metrics import Metrics
4141
from .logging import logger
42+
from .providers import ProviderHealthMonitor
4243

4344
AdapterMap = {
4445
"alpaca": AlpacaAdapter,
@@ -58,6 +59,7 @@ class StreamingGateway:
5859
buffer: StreamBuffer = field(init=False)
5960
metrics: Metrics = field(init=False)
6061
health_monitor: StreamingHealthMonitor = field(init=False)
62+
provider_monitor: ProviderHealthMonitor = field(init=False)
6163
_subscriptions: List[Tuple[str, List[str]]] = field(default_factory=list)
6264
_callbacks: Dict[str, List[Callback]] = field(default_factory=dict)
6365
_config_symbols: List[str] = field(default_factory=list)
@@ -98,6 +100,9 @@ def __post_init__(self) -> None:
98100
queue_size_fn=self.buffer.queue.qsize,
99101
queue_name="stream",
100102
)
103+
self.provider_monitor = ProviderHealthMonitor(
104+
streaming_monitor=self.health_monitor
105+
)
101106
api_cfg = health_cfg.get("api", {})
102107
self._api_enabled = api_cfg.get("enabled", False)
103108
self._api_host = api_cfg.get("host", "0.0.0.0")
@@ -117,6 +122,17 @@ def __post_init__(self) -> None:
117122
rate_limit_cfg=provider_cfg.get("rate_limit"),
118123
auth_method=provider_cfg.get("auth_method", "api_key"),
119124
)
125+
126+
async def _failover(adapter=adapter) -> None:
127+
await self.websocket_manager._connect_with_retry(
128+
adapter,
129+
monitor=self.provider_monitor,
130+
)
131+
132+
self.provider_monitor.register_provider(
133+
adapter.name,
134+
failover_handler=_failover,
135+
)
120136
# Config-driven subscriptions (optional)
121137
subs = provider_cfg.get("subscriptions", []) or []
122138
prov_symbols = provider_cfg.get("symbols", self._config_symbols) or []
@@ -159,7 +175,7 @@ async def _dispatch(self, provider: str, message: Dict) -> None:
159175
logger.error("callback_error", error=str(exc))
160176

161177
async def _start(self) -> None:
162-
await self.websocket_manager.connect_all()
178+
await self.websocket_manager.connect_all(monitor=self.provider_monitor)
163179
for adapter in self.websocket_manager.adapters:
164180
self.health_monitor.register_connection(adapter.name)
165181
asyncio.create_task(self.health_monitor.monitor_connection_health())
@@ -188,7 +204,18 @@ async def _start(self) -> None:
188204
)
189205
for channel, symbols in self._subscriptions:
190206
for adapter in self.websocket_manager.adapters:
191-
await adapter.subscribe(channel, symbols)
207+
208+
async def _subscribe(
209+
adapter=adapter,
210+
channel=channel,
211+
symbols=symbols,
212+
) -> None:
213+
await adapter.subscribe(channel, symbols)
214+
215+
await self.provider_monitor.execute_with_health(
216+
adapter.name,
217+
_subscribe,
218+
)
192219
await self.websocket_manager.run(self._dispatch)
193220

194221
def start_streaming(self) -> None:

quanttradeai/streaming/providers/health.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import time
88
from collections import defaultdict, deque
99
from datetime import datetime, timezone
10-
from typing import Awaitable, Callable, Deque, Dict, Optional, TypeVar
10+
from typing import Awaitable, Callable, Deque, Dict, Optional, Set, TypeVar
1111

1212
from pybreaker import CircuitBreaker, CircuitBreakerError
1313

@@ -40,6 +40,7 @@ def __init__(
4040
self._failover_handlers: Dict[str, Callable[[], Awaitable[None]]] = {}
4141
self._circuit_breakers: Dict[str, CircuitBreaker] = {}
4242
self._status_sources: Dict[str, Callable[[], ProviderHealthStatus]] = {}
43+
self._active_failovers: Set[str] = set()
4344
self._lock = asyncio.Lock()
4445
self.error_window = error_window
4546
self.error_threshold = error_threshold
@@ -109,6 +110,9 @@ async def _trigger_failover(self, provider_name: str) -> None:
109110
handler = self._failover_handlers.get(provider_name)
110111
if handler is None:
111112
return
113+
if provider_name in self._active_failovers:
114+
return
115+
self._active_failovers.add(provider_name)
112116
try:
113117
async with self._lock:
114118
status = self._statuses.setdefault(
@@ -123,6 +127,8 @@ async def _trigger_failover(self, provider_name: str) -> None:
123127
"provider_failover_failed",
124128
extra={"provider": provider_name, "error": str(exc)},
125129
)
130+
finally:
131+
self._active_failovers.discard(provider_name)
126132

127133
# ------------------------------------------------------------------
128134
async def execute_with_health(

quanttradeai/streaming/websocket_manager.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .rate_limiter import AdaptiveRateLimiter
1414
from .connection_pool import ConnectionPool
1515
from .logging import logger
16+
from .providers.health import ProviderHealthMonitor
1617

1718
Callback = Callable[[str, dict], Awaitable[None]]
1819

@@ -49,11 +50,23 @@ def add_adapter(
4950
adapter.auth_manager = AuthManager(adapter.name)
5051
self.adapters.append(adapter)
5152

52-
async def _connect_with_retry(self, adapter: DataProviderAdapter) -> None:
53+
async def _connect_with_retry(
54+
self,
55+
adapter: DataProviderAdapter,
56+
*,
57+
monitor: Optional[ProviderHealthMonitor] = None,
58+
) -> None:
5359
delay = 1.0
5460
for attempt in range(self.reconnect_attempts):
5561
try:
56-
await adapter.circuit_breaker.call_async(adapter.connect)
62+
if monitor is not None:
63+
64+
async def _operation() -> None:
65+
await adapter.circuit_breaker.call_async(adapter.connect)
66+
67+
await monitor.execute_with_health(adapter.name, _operation)
68+
else:
69+
await adapter.circuit_breaker.call_async(adapter.connect)
5770
return
5871
except CircuitBreakerError as exc:
5972
logger.error("circuit_open", provider=adapter.name, error=str(exc))
@@ -70,11 +83,13 @@ async def _connect_with_retry(self, adapter: DataProviderAdapter) -> None:
7083
await asyncio.sleep(delay)
7184
delay *= 2
7285

73-
async def connect_all(self) -> None:
86+
async def connect_all(
87+
self, *, monitor: Optional[ProviderHealthMonitor] = None
88+
) -> None:
7489
"""Connect all registered adapters."""
7590

7691
for adapter in self.adapters:
77-
await self._connect_with_retry(adapter)
92+
await self._connect_with_retry(adapter, monitor=monitor)
7893
self.connection_pool._active.update(
7994
ad.connection for ad in self.adapters if ad.connection
8095
)

tests/streaming/test_gateway.py

Lines changed: 176 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,75 @@
1-
import asyncio
2-
import json
3-
import tempfile
4-
import asyncio
5-
import json
6-
import tempfile
7-
from unittest.mock import patch
8-
9-
import yaml
10-
11-
from quanttradeai.streaming import StreamingGateway
1+
import asyncio
2+
import json
3+
import tempfile
4+
import time
5+
from typing import Awaitable, Callable, Dict, List, Optional
6+
from unittest.mock import AsyncMock, patch
7+
8+
import yaml
9+
10+
from quanttradeai.streaming.monitoring import StreamingHealthMonitor
11+
12+
import pytest
13+
from quanttradeai.streaming import StreamingGateway
14+
15+
16+
class StubProviderMonitor:
17+
def __init__(
18+
self,
19+
*,
20+
streaming_monitor: Optional[StreamingHealthMonitor] = None,
21+
**_: Dict,
22+
) -> None:
23+
self.streaming_monitor = streaming_monitor or StreamingHealthMonitor()
24+
self.recovery_manager = self.streaming_monitor.recovery_manager
25+
self.registered: List[str] = []
26+
self.failover_handlers: Dict[str, Callable[[], Awaitable[None]]] = {}
27+
self.execute_calls: List[str] = []
28+
self.record_success_calls: List[float] = []
29+
self.record_failure_calls: List[str] = []
30+
self.status_providers: Dict[str, Callable[[], object]] = {}
31+
32+
def register_provider(
33+
self,
34+
provider_name: str,
35+
*,
36+
failover_handler: Optional[Callable[[], Awaitable[None]]] = None,
37+
status_provider: Optional[Callable[[], object]] = None,
38+
) -> None:
39+
self.streaming_monitor.register_connection(provider_name)
40+
self.registered.append(provider_name)
41+
if failover_handler is not None:
42+
self.failover_handlers[provider_name] = failover_handler
43+
if status_provider is not None:
44+
self.status_providers[provider_name] = status_provider
45+
46+
async def execute_with_health(
47+
self,
48+
provider_name: str,
49+
operation: Callable[[], Awaitable[object]],
50+
*,
51+
fallback: Optional[Callable[[], Awaitable[object]]] = None,
52+
) -> object:
53+
self.execute_calls.append(provider_name)
54+
start = time.perf_counter()
55+
try:
56+
result = await operation()
57+
except Exception as exc:
58+
await self.record_failure(provider_name, exc)
59+
if fallback is not None:
60+
return await fallback()
61+
raise
62+
latency_ms = (time.perf_counter() - start) * 1000.0
63+
await self.record_success(provider_name, latency_ms)
64+
return result
65+
66+
async def record_success(
67+
self, provider_name: str, latency_ms: float, *, bytes_received: int = 0
68+
) -> None:
69+
self.record_success_calls.append(latency_ms)
70+
71+
async def record_failure(self, provider_name: str, error: Exception) -> None:
72+
self.record_failure_calls.append(provider_name)
1273

1374

1475
class FakeConnection:
@@ -31,11 +92,11 @@ async def close(self):
3192
pass
3293

3394

34-
def test_gateway_streaming():
35-
msg = json.dumps({"type": "trades", "symbol": "TEST", "price": 1})
36-
37-
async def connect(url, *_, **__):
38-
return FakeConnection([msg])
95+
def test_gateway_streaming():
96+
msg = json.dumps({"type": "trades", "symbol": "TEST", "price": 1})
97+
98+
async def connect(url, *_, **__):
99+
return FakeConnection([msg])
39100

40101
async def run_test():
41102
with patch("websockets.connect", new=connect):
@@ -68,4 +129,102 @@ async def run_test():
68129
except Exception:
69130
pass
70131

71-
asyncio.run(run_test())
132+
asyncio.run(run_test())
133+
134+
135+
@patch("quanttradeai.streaming.gateway.ProviderHealthMonitor", new=StubProviderMonitor)
136+
def test_gateway_registers_providers_and_failover(tmp_path):
137+
cfg = {
138+
"streaming": {
139+
"providers": [
140+
{
141+
"name": "alpaca",
142+
"websocket_url": "ws://test",
143+
"auth_method": "none",
144+
}
145+
]
146+
}
147+
}
148+
config_file = tmp_path / "streaming.yaml"
149+
config_file.write_text(yaml.safe_dump(cfg))
150+
gateway = StreamingGateway(str(config_file))
151+
monitor = gateway.provider_monitor
152+
assert monitor.registered == ["alpaca"]
153+
adapter = gateway.websocket_manager.adapters[0]
154+
gateway.websocket_manager._connect_with_retry = AsyncMock()
155+
failover = monitor.failover_handlers[adapter.name]
156+
asyncio.run(failover())
157+
gateway.websocket_manager._connect_with_retry.assert_awaited_once()
158+
_, kwargs = gateway.websocket_manager._connect_with_retry.await_args
159+
assert kwargs["monitor"] is monitor
160+
161+
162+
@patch("quanttradeai.streaming.gateway.ProviderHealthMonitor", new=StubProviderMonitor)
163+
def test_gateway_start_uses_provider_monitor(tmp_path):
164+
cfg = {
165+
"streaming": {
166+
"providers": [
167+
{
168+
"name": "alpaca",
169+
"websocket_url": "ws://test",
170+
"auth_method": "none",
171+
}
172+
]
173+
}
174+
}
175+
config_file = tmp_path / "streaming.yaml"
176+
config_file.write_text(yaml.safe_dump(cfg))
177+
gateway = StreamingGateway(str(config_file))
178+
adapter = gateway.websocket_manager.adapters[0]
179+
gateway.subscribe_to_trades(["TEST"], callback=lambda _: None)
180+
adapter.subscribe = AsyncMock(return_value=None)
181+
gateway.websocket_manager.connect_all = AsyncMock()
182+
gateway.websocket_manager.run = AsyncMock()
183+
gateway.health_monitor.monitor_connection_health = AsyncMock()
184+
185+
async def run_start():
186+
await gateway._start()
187+
188+
asyncio.run(run_start())
189+
190+
gateway.websocket_manager.connect_all.assert_awaited_once()
191+
_, kwargs = gateway.websocket_manager.connect_all.await_args
192+
assert kwargs["monitor"] is gateway.provider_monitor
193+
adapter.subscribe.assert_awaited_once()
194+
assert gateway.provider_monitor.execute_calls.count("alpaca") == len(
195+
gateway._subscriptions
196+
)
197+
198+
199+
@patch("quanttradeai.streaming.gateway.ProviderHealthMonitor", new=StubProviderMonitor)
200+
def test_websocket_manager_reports_failures(tmp_path):
201+
cfg = {
202+
"streaming": {
203+
"providers": [
204+
{
205+
"name": "alpaca",
206+
"websocket_url": "ws://test",
207+
"auth_method": "none",
208+
}
209+
]
210+
}
211+
}
212+
config_file = tmp_path / "streaming.yaml"
213+
config_file.write_text(yaml.safe_dump(cfg))
214+
gateway = StreamingGateway(str(config_file))
215+
monitor = gateway.provider_monitor
216+
manager = gateway.websocket_manager
217+
adapter = manager.adapters[0]
218+
manager.reconnect_attempts = 1
219+
220+
async def failing_connect():
221+
raise RuntimeError("boom")
222+
223+
adapter.connect = AsyncMock(side_effect=failing_connect)
224+
225+
async def run_connect():
226+
await manager._connect_with_retry(adapter, monitor=monitor)
227+
228+
with pytest.raises(RuntimeError):
229+
asyncio.run(run_connect())
230+
assert monitor.record_failure_calls

0 commit comments

Comments
 (0)