diff --git a/README.md b/README.md index 28db5c5..d996adf 100644 --- a/README.md +++ b/README.md @@ -148,15 +148,53 @@ streaming: health_check_interval: 30 ``` -```python -from quanttradeai.streaming import StreamingGateway - -gw = StreamingGateway("config/streaming.yaml") -gw.subscribe_to_trades(["AAPL"], lambda m: print("TRADE", m)) -# gw.start_streaming() # blocking -``` - -## Project Layout +```python +from quanttradeai.streaming import StreamingGateway + +gw = StreamingGateway("config/streaming.yaml") +gw.subscribe_to_trades(["AAPL"], lambda m: print("TRADE", m)) +# gw.start_streaming() # blocking +``` + +### Streaming Health Monitoring + +Enable advanced monitoring by adding a `streaming_health` section to your config and, +optionally, starting the embedded REST server: + +```yaml +streaming_health: + monitoring: + enabled: true + check_interval: 5 + thresholds: + max_latency_ms: 100 + min_throughput_msg_per_sec: 50 + max_queue_depth: 5000 + alerts: + enabled: true + channels: ["log", "metrics"] + escalation_threshold: 3 + api: + enabled: true + host: "0.0.0.0" + port: 8000 +``` + +Query live status while streaming: + +```bash +curl http://localhost:8000/health # readiness probe +curl http://localhost:8000/status # detailed metrics + incidents +curl http://localhost:8000/metrics # Prometheus scrape +``` + +Common patterns: + +- Tune `escalation_threshold` to control alert promotion. +- Increase `max_queue_depth` in high-volume environments. +- Set `circuit_breaker_timeout` to avoid thrashing unstable providers. 
+ +## Project Layout ``` quanttradeai/ # Core package diff --git a/config/streaming.yaml b/config/streaming.yaml index ecb63df..1bae746 100644 --- a/config/streaming.yaml +++ b/config/streaming.yaml @@ -15,3 +15,22 @@ streaming: buffer_size: 10000 reconnect_attempts: 5 health_check_interval: 30 +streaming_health: + monitoring: + enabled: true + check_interval: 5 + metrics_retention: 3600 + thresholds: + max_latency_ms: 100 + min_throughput_msg_per_sec: 50 + max_reconnect_attempts: 5 + max_queue_depth: 5000 + circuit_breaker_timeout: 60 + alerts: + enabled: true + channels: ["log", "metrics"] + escalation_threshold: 3 + api: + enabled: false + host: "0.0.0.0" + port: 8000 diff --git a/docs/api/streaming.md b/docs/api/streaming.md index 38a2593..2c54fb6 100644 --- a/docs/api/streaming.md +++ b/docs/api/streaming.md @@ -90,3 +90,52 @@ Register it in your runtime (or fork and extend `AdapterMap` in `StreamingGatewa - Prometheus metrics (`prometheus_client`) track message counts, connection latency, and active connections. - Optional background health checks ping pooled connections (interval configured via YAML). +### Advanced Health Metrics + +- **Message loss detection** surfaces gaps in sequence numbers and reports per-provider drop rates. +- **Queue depth gauges** expose backlog in internal processing buffers. +- **Bandwidth and throughput** statistics track messages per second and bytes processed. +- **Data freshness** timers flag stale feeds when updates stop arriving. + +### Alerting & Incident History + +- Configurable thresholds escalate repeated warnings to errors after a defined count. +- Incident history is retained in memory for post-mortem analysis and optional export. +- Alert channels include structured logs and Prometheus-compatible metrics. + +### Recovery & Circuit Breaking + +- Automatic retries use exponential backoff with jitter and respect circuit-breaker timeouts. +- Fallback connectors can be configured for provider outages. 
+ +### Health API + +When enabled, an embedded REST server exposes: + +- `GET /health` – lightweight readiness probe. +- `GET /status` – detailed status including recent incidents. +- `GET /metrics` – Prometheus scrape endpoint. + +### Configuration Example + +```yaml +streaming_health: + monitoring: + enabled: true + check_interval: 5 + thresholds: + max_latency_ms: 100 + min_throughput_msg_per_sec: 50 + max_queue_depth: 5000 + circuit_breaker_timeout: 60 + alerts: + enabled: true + channels: ["log", "metrics"] + escalation_threshold: 3 + api: + enabled: true + host: "0.0.0.0" + port: 8000 +``` + + diff --git a/docs/examples/streaming.md b/docs/examples/streaming.md index 9fdc01a..c34cfba 100644 --- a/docs/examples/streaming.md +++ b/docs/examples/streaming.md @@ -75,3 +75,45 @@ manager.add_adapter(MyAdapter(websocket_url="wss://example"), auth_method="none" # Use StreamingGateway or call manager.connect_all()/run() directly. ``` +## Health Monitoring Configuration + +Append the following to `config/streaming.yaml` to enable comprehensive health checks: + +```yaml +streaming_health: + monitoring: + enabled: true + check_interval: 5 + thresholds: + max_latency_ms: 100 + min_throughput_msg_per_sec: 50 + max_queue_depth: 5000 + circuit_breaker_timeout: 60 + alerts: + enabled: true + channels: ["log", "metrics"] + escalation_threshold: 3 + api: + enabled: true + host: "0.0.0.0" + port: 8000 +``` + +Start the gateway and query metrics: + +```bash +curl http://localhost:8000/status +``` + +## Alert Threshold Tuning + +- Increase `max_latency_ms` or increase `escalation_threshold` for noisy networks. +- Monitor `max_queue_depth` during peak sessions and adjust to avoid drops. +- Use Prometheus metrics to derive realistic throughput baselines. + +## Production Deployment Recommendations + +- Run the health API behind an authenticated ingress when exposing publicly. +- Scrape `/metrics` with Prometheus and forward alerts to your incident system. 
+- Configure fallback providers to ensure continuity during outages. + diff --git a/poetry.lock b/poetry.lock index 778a8d4..e71fa94 100644 --- a/poetry.lock +++ b/poetry.lock @@ -984,6 +984,28 @@ files = [ [package.extras] tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""] +[[package]] +name = "fastapi" +version = "0.116.1" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "fastapi-0.116.1-py3-none-any.whl", hash = "sha256:c46ac7c312df840f0c9e220f7964bada936781bc4e2e6eb71f1c4d7553786565"}, + {file = "fastapi-0.116.1.tar.gz", hash = "sha256:ed52cbf946abfd70c5a0dccb24673f0670deeb517a88b3544d03c2a6bf283143"}, +] + +[package.dependencies] +pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" +starlette = ">=0.40.0,<0.48.0" +typing-extensions = ">=4.8.0" + +[package.extras] +all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0)", "jinja2 (>=3.1.5)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"] +standard-no-fastapi-cloud-cli = ["email-validator (>=2.0.0)", "fastapi-cli[standard-no-fastapi-cloud-cli] (>=0.0.8)", "httpx (>=0.23.0)", "jinja2 (>=3.1.5)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"] + [[package]] name = "fastjsonschema" version = "2.21.2" @@ -4885,6 +4907,25 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", 
"pygments", "pytest", "typeguard"] +[[package]] +name = "starlette" +version = "0.47.3" +description = "The little ASGI library that shines." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "starlette-0.47.3-py3-none-any.whl", hash = "sha256:89c0778ca62a76b826101e7c709e70680a1699ca7da6b44d38eb0a7e61fe4b51"}, + {file = "starlette-0.47.3.tar.gz", hash = "sha256:6bc94f839cc176c4858894f1f8908f0ab79dfec1a6b8402f6da9be26ebea52e9"}, +] + +[package.dependencies] +anyio = ">=3.6.2,<5" +typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""} + +[package.extras] +full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] + [[package]] name = "structlog" version = "25.4.0" @@ -5193,6 +5234,25 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "uvicorn" +version = "0.35.0" +description = "The lightning-fast ASGI server." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "uvicorn-0.35.0-py3-none-any.whl", hash = "sha256:197535216b25ff9b785e29a0b79199f55222193d47f820816e7da751e9bc8d4a"}, + {file = "uvicorn-0.35.0.tar.gz", hash = "sha256:bc662f087f7cf2ce11a1d7fd70b90c9f98ef2e2831556dd078d131b96cc94a01"}, +] + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" + +[package.extras] +standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [[package]] name = "virtualenv" version = "20.34.0" @@ -5560,4 +5620,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<4.0" -content-hash = "96509fa1b271989d2fcc2cce20e4967a3a8e2d037ab1174fc4edf327192634c0" +content-hash = 
"15b2de391055d4b20e6b8582b24f22ada29819f7ec03e02774230c0a9b072776" diff --git a/pyproject.toml b/pyproject.toml index 1e2b0cb..331c3ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,9 @@ dependencies = [ "pybreaker (>=1.4.0,<2.0.0)", "asyncio-throttle (>=1.0.2,<2.0.0)", "prometheus-client (>=0.22.1,<0.23.0)", - "typer (>=0.16.1,<0.17.0)" + "typer (>=0.16.1,<0.17.0)", + "fastapi (>=0.115.0,<1.0.0)", + "uvicorn (>=0.30.0,<1.0.0)" ] [tool.poetry] diff --git a/quanttradeai/streaming/__init__.py b/quanttradeai/streaming/__init__.py index 473f8fa..0475555 100644 --- a/quanttradeai/streaming/__init__.py +++ b/quanttradeai/streaming/__init__.py @@ -1,13 +1,15 @@ """Streaming infrastructure package.""" -from .gateway import StreamingGateway -from .auth_manager import AuthManager -from .rate_limiter import AdaptiveRateLimiter -from .connection_pool import ConnectionPool +from .gateway import StreamingGateway +from .auth_manager import AuthManager +from .rate_limiter import AdaptiveRateLimiter +from .connection_pool import ConnectionPool +from .monitoring import StreamingHealthMonitor __all__ = [ - "StreamingGateway", - "AuthManager", - "AdaptiveRateLimiter", - "ConnectionPool", -] + "StreamingGateway", + "AuthManager", + "AdaptiveRateLimiter", + "ConnectionPool", + "StreamingHealthMonitor", +] diff --git a/quanttradeai/streaming/gateway.py b/quanttradeai/streaming/gateway.py index 75720a4..919824a 100644 --- a/quanttradeai/streaming/gateway.py +++ b/quanttradeai/streaming/gateway.py @@ -16,45 +16,53 @@ >>> gw.subscribe_to_quotes(["MSFT", "TSLA"], callback=lambda m: print(m)) >>> # gw.start_streaming() """ - -from __future__ import annotations - -import asyncio -from dataclasses import dataclass, field + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field from typing import Callable, Dict, List, Tuple - -import yaml - -from .stream_buffer import StreamBuffer -from .websocket_manager import WebSocketManager -from 
.adapters.alpaca_adapter import AlpacaAdapter -from .adapters.ib_adapter import IBAdapter -from .processors import MessageProcessor -from .monitoring.metrics import Metrics -from .logging import logger - -AdapterMap = { - "alpaca": AlpacaAdapter, - "interactive_brokers": IBAdapter, -} - -Callback = Callable[[str, Dict], None] - - -@dataclass -class StreamingGateway: - """Main entry point for real-time data streaming.""" - - config_path: str - websocket_manager: WebSocketManager = field(init=False) - message_processor: MessageProcessor = field(init=False) - buffer: StreamBuffer = field(init=False) - metrics: Metrics = field(init=False) + +import yaml + +from .stream_buffer import StreamBuffer +from .websocket_manager import WebSocketManager +from .adapters.alpaca_adapter import AlpacaAdapter +from .adapters.ib_adapter import IBAdapter +from .processors import MessageProcessor +from .monitoring import ( + AlertManager, + MetricsCollector, + RecoveryManager, + StreamingHealthMonitor, + create_health_app, +) +from .monitoring.metrics import Metrics +from .logging import logger + +AdapterMap = { + "alpaca": AlpacaAdapter, + "interactive_brokers": IBAdapter, +} + +Callback = Callable[[str, Dict], None] + + +@dataclass +class StreamingGateway: + """Main entry point for real-time data streaming.""" + + config_path: str + websocket_manager: WebSocketManager = field(init=False) + message_processor: MessageProcessor = field(init=False) + buffer: StreamBuffer = field(init=False) + metrics: Metrics = field(init=False) + health_monitor: StreamingHealthMonitor = field(init=False) _subscriptions: List[Tuple[str, List[str]]] = field(default_factory=list) _callbacks: Dict[str, List[Callback]] = field(default_factory=dict) _config_symbols: List[str] = field(default_factory=list) - - def __post_init__(self) -> None: + + def __post_init__(self) -> None: with open(self.config_path, "r") as f: cfg_all = yaml.safe_load(f) cfg = cfg_all.get("streaming", {}) @@ -64,6 +72,36 @@ def 
__post_init__(self) -> None: self.message_processor = MessageProcessor() self.buffer = StreamBuffer(cfg.get("buffer_size", 1000)) self.metrics = Metrics() + health_cfg = cfg_all.get("streaming_health", {}) + mon_cfg = health_cfg.get("monitoring", {}) + thresh_cfg = health_cfg.get("thresholds", {}) + alert_cfg = health_cfg.get("alerts", {}) + self.health_monitor = StreamingHealthMonitor( + metrics_collector=MetricsCollector(), + alert_manager=AlertManager( + alert_cfg.get("channels", ["log"]), + escalation_threshold=alert_cfg.get("escalation_threshold", 3), + ), + recovery_manager=RecoveryManager( + max_attempts=thresh_cfg.get("max_reconnect_attempts", 5), + reset_timeout=thresh_cfg.get("circuit_breaker_timeout", 60), + ), + check_interval=mon_cfg.get("check_interval", 5), + thresholds={ + "max_latency_ms": thresh_cfg.get("max_latency_ms", 100), + "min_throughput_msg_per_sec": thresh_cfg.get( + "min_throughput_msg_per_sec", 0 + ), + "max_reconnect_attempts": thresh_cfg.get("max_reconnect_attempts", 5), + "max_queue_depth": thresh_cfg.get("max_queue_depth", 0), + }, + queue_size_fn=self.buffer.queue.qsize, + queue_name="stream", + ) + api_cfg = health_cfg.get("api", {}) + self._api_enabled = api_cfg.get("enabled", False) + self._api_host = api_cfg.get("host", "0.0.0.0") + self._api_port = api_cfg.get("port", 8000) # Top-level symbol list for convenience self._config_symbols = cfg.get("symbols", []) or [] for provider_cfg in cfg.get("providers", []): @@ -85,30 +123,33 @@ def __post_init__(self) -> None: for channel in subs: # Store once; adapters subscribe per provider in _start self._subscriptions.append((channel, prov_symbols)) - - # Subscription API ------------------------------------------------- - def subscribe_to_trades(self, symbols: List[str], callback: Callback) -> None: - """Subscribe to trade updates for ``symbols``.""" - - self._subscriptions.append(("trades", symbols)) - self._callbacks.setdefault("trades", []).append(callback) - - def 
subscribe_to_quotes(self, symbols: List[str], callback: Callback) -> None: - """Subscribe to quote updates for ``symbols``.""" - - self._subscriptions.append(("quotes", symbols)) - self._callbacks.setdefault("quotes", []).append(callback) - - # Runtime ---------------------------------------------------------- - async def _dispatch(self, provider: str, message: Dict) -> None: - symbol = message.get("symbol", "") - self.metrics.record_message(provider, symbol) - processed = self.message_processor.process(message) - await self.buffer.put(processed) - msg_type = message.get("type", "trades") - logger.debug( - "dispatch_message", provider=provider, type=msg_type, symbol=symbol - ) + + # Subscription API ------------------------------------------------- + def subscribe_to_trades(self, symbols: List[str], callback: Callback) -> None: + """Subscribe to trade updates for ``symbols``.""" + + self._subscriptions.append(("trades", symbols)) + self._callbacks.setdefault("trades", []).append(callback) + + def subscribe_to_quotes(self, symbols: List[str], callback: Callback) -> None: + """Subscribe to quote updates for ``symbols``.""" + + self._subscriptions.append(("quotes", symbols)) + self._callbacks.setdefault("quotes", []).append(callback) + + # Runtime ---------------------------------------------------------- + async def _dispatch(self, provider: str, message: Dict) -> None: + symbol = message.get("symbol", "") + self.metrics.record_message(provider, symbol) + seq = message.get("sequence") + size = len(str(message)) + self.health_monitor.record_message(provider, sequence=seq, size_bytes=size) + processed = self.message_processor.process(message) + await self.buffer.put(processed) + msg_type = message.get("type", "trades") + logger.debug( + "dispatch_message", provider=provider, type=msg_type, symbol=symbol + ) for cb in self._callbacks.get(msg_type, []): try: res = cb(processed) @@ -116,10 +157,25 @@ async def _dispatch(self, provider: str, message: Dict) -> None: await 
res except Exception as exc: # keep streaming robust to callback errors logger.error("callback_error", error=str(exc)) - + async def _start(self) -> None: await self.websocket_manager.connect_all() - # Optional health monitoring loop in the background + for adapter in self.websocket_manager.adapters: + self.health_monitor.register_connection(adapter.name) + asyncio.create_task(self.health_monitor.monitor_connection_health()) + if self._api_enabled: + try: + import uvicorn + + app = create_health_app(self.health_monitor) + config = uvicorn.Config( + app, host=self._api_host, port=self._api_port, log_level="warning" + ) + server = uvicorn.Server(config) + asyncio.create_task(server.serve()) + except Exception: + logger.warning("health_api_start_failed") + # Optional legacy health check for connection pool try: with open(self.config_path, "r") as f: cfg = yaml.safe_load(f).get("streaming", {}) @@ -134,8 +190,8 @@ async def _start(self) -> None: for adapter in self.websocket_manager.adapters: await adapter.subscribe(channel, symbols) await self.websocket_manager.run(self._dispatch) - - def start_streaming(self) -> None: - """Blocking call that starts the streaming event loop.""" - - asyncio.run(self._start()) + + def start_streaming(self) -> None: + """Blocking call that starts the streaming event loop.""" + + asyncio.run(self._start()) diff --git a/quanttradeai/streaming/monitoring/__init__.py b/quanttradeai/streaming/monitoring/__init__.py index eb0aafd..e3a16ee 100644 --- a/quanttradeai/streaming/monitoring/__init__.py +++ b/quanttradeai/streaming/monitoring/__init__.py @@ -1,6 +1,18 @@ -"""Monitoring utilities for streaming system.""" - -from .health_monitor import HealthMonitor -from .metrics import Metrics - -__all__ = ["HealthMonitor", "Metrics"] +"""Monitoring utilities for streaming system.""" + +from .alerts import AlertManager +from .api import create_health_app +from .health_monitor import ConnectionHealth, StreamingHealthMonitor +from .metrics import 
Metrics +from .metrics_collector import MetricsCollector +from .recovery_manager import RecoveryManager + +__all__ = [ + "AlertManager", + "ConnectionHealth", + "StreamingHealthMonitor", + "MetricsCollector", + "RecoveryManager", + "Metrics", + "create_health_app", +] diff --git a/quanttradeai/streaming/monitoring/alerts.py b/quanttradeai/streaming/monitoring/alerts.py new file mode 100644 index 0000000..9529d27 --- /dev/null +++ b/quanttradeai/streaming/monitoring/alerts.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +"""Alert management utilities for streaming health monitoring.""" + +from dataclasses import dataclass, field +from typing import Callable, Dict, Iterable, List, Tuple +import logging +import time + +logger = logging.getLogger(__name__) + + +@dataclass +class AlertManager: + """Dispatch alerts through configurable channels. + + Parameters + ---------- + channels: + Iterable of channels to emit alerts on. Supported values are + ``"log"`` and ``"metrics"``. Additional channels can be plugged in by + registering callback functions via :attr:`callbacks`. + """ + + channels: Iterable[str] = ("log",) + escalation_threshold: int = 3 + callbacks: List[Callable[[str, str], None]] = field(default_factory=list) + warning_counts: Dict[str, int] = field(default_factory=dict) + history: List[Tuple[float, str, str]] = field(default_factory=list) + + def _dispatch(self, level: str, message: str) -> None: + if "log" in self.channels: + log_fn = getattr(logger, level, logger.warning) + log_fn(message) + if "metrics" in self.channels: + # Integration point for metric-based alerting (e.g. Prometheus + # counters). Left as a no-op for lightweight deployments. + pass + for cb in self.callbacks: + cb(level, message) + + def send(self, level: str, message: str) -> None: + """Send an alert with ``level`` and ``message``. + + Escalates after ``escalation_threshold`` warnings and records a simple + in-memory incident history for later inspection. 
+ """ + + self.history.append((time.time(), level, message)) + if level == "warning": + count = self.warning_counts.get(message, 0) + 1 + self.warning_counts[message] = count + if count >= self.escalation_threshold: + self.warning_counts[message] = 0 + self._dispatch("error", message) + return + self._dispatch(level, message) diff --git a/quanttradeai/streaming/monitoring/api.py b/quanttradeai/streaming/monitoring/api.py new file mode 100644 index 0000000..c862c13 --- /dev/null +++ b/quanttradeai/streaming/monitoring/api.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +"""REST API exposing streaming health metrics.""" + +from fastapi import FastAPI, Response +from prometheus_client import CONTENT_TYPE_LATEST, generate_latest + +from .health_monitor import StreamingHealthMonitor + + +def create_health_app(monitor: StreamingHealthMonitor) -> FastAPI: + """Create a FastAPI app exposing health endpoints.""" + + app = FastAPI() + + @app.get("/health") + @app.get("/status") + async def health() -> dict: # pragma: no cover - simple return + return monitor.generate_health_report() + + @app.get("/metrics") + def metrics() -> Response: # pragma: no cover - simple return + data = generate_latest() + return Response(content=data, media_type=CONTENT_TYPE_LATEST) + + return app diff --git a/quanttradeai/streaming/monitoring/health_monitor.py b/quanttradeai/streaming/monitoring/health_monitor.py index 60160a4..19f71dd 100644 --- a/quanttradeai/streaming/monitoring/health_monitor.py +++ b/quanttradeai/streaming/monitoring/health_monitor.py @@ -1,19 +1,179 @@ -"""Health monitoring utilities.""" - -from __future__ import annotations - -import asyncio -from dataclasses import dataclass - - -@dataclass -class HealthMonitor: - """Periodically perform a no-op health check.""" - - interval: int = 30 - - async def run(self) -> None: - """Run the health check loop.""" - - while True: - await asyncio.sleep(self.interval) +from __future__ import annotations + +"""Streaming 
connection health monitoring.""" + +import asyncio +import time +from dataclasses import dataclass, field +from typing import Callable, Dict, Optional + +from .alerts import AlertManager +from .metrics_collector import MetricsCollector +from .recovery_manager import RecoveryManager + + +@dataclass +class ConnectionHealth: + """State information about a single streaming connection.""" + + last_message_ts: float = field(default_factory=lambda: time.time()) + messages: int = 0 + bytes_received: int = 0 + window_start: float = field(default_factory=lambda: time.time()) + latency_ms: float = 0.0 + status: str = "connected" + reconnect_attempts: int = 0 + last_sequence: Optional[int] = None + lost_messages: int = 0 + + def record_message( + self, sequence: Optional[int] = None, size_bytes: Optional[int] = None + ) -> None: + self.messages += 1 + self.last_message_ts = time.time() + if size_bytes is not None: + self.bytes_received += size_bytes + if sequence is not None: + if self.last_sequence is not None and sequence > self.last_sequence + 1: + self.lost_messages += sequence - self.last_sequence - 1 + self.last_sequence = sequence + + +@dataclass +class StreamingHealthMonitor: + """Monitor streaming connections and collect metrics.""" + + connection_status: Dict[str, ConnectionHealth] = field(default_factory=dict) + metrics_collector: MetricsCollector = field(default_factory=MetricsCollector) + alert_manager: AlertManager = field(default_factory=AlertManager) + recovery_manager: RecoveryManager = field(default_factory=RecoveryManager) + check_interval: float = 5.0 + thresholds: Dict[str, float] = field(default_factory=dict) + queue_size_fn: Optional[Callable[[], int]] = None + queue_name: str = "stream" + _running: bool = field(default=False, init=False) + + # ------------------------------------------------------------------ + def register_connection(self, name: str) -> None: + self.connection_status.setdefault(name, ConnectionHealth()) + + def record_message( + self, 
+ name: str, + *, + sequence: Optional[int] = None, + size_bytes: Optional[int] = None, + ) -> None: + if name not in self.connection_status: + self.register_connection(name) + self.connection_status[name].record_message(sequence, size_bytes) + + def record_latency(self, name: str, latency_ms: float) -> None: + if name not in self.connection_status: + self.register_connection(name) + self.connection_status[name].latency_ms = latency_ms + + # ------------------------------------------------------------------ + async def monitor_connection_health(self) -> None: + self._running = True + while self._running: + await self._check_connections_once() + await asyncio.sleep(self.check_interval) + + async def _check_connections_once(self) -> None: + now = time.time() + for name, health in self.connection_status.items(): + elapsed = now - health.window_start + throughput = health.messages / elapsed if elapsed > 0 else 0.0 + bandwidth = health.bytes_received / elapsed if elapsed > 0 else 0.0 + age = now - health.last_message_ts + self.metrics_collector.record_throughput(name, throughput) + self.metrics_collector.record_bandwidth(name, bandwidth) + self.metrics_collector.record_data_freshness(name, age) + if health.latency_ms: + self.metrics_collector.record_latency(name, health.latency_ms) + + max_lat = self.thresholds.get("max_latency_ms") + if max_lat is not None and health.latency_ms > max_lat: + self.trigger_alerts("warning", f"{name} latency {health.latency_ms}ms") + + min_tp = self.thresholds.get("min_throughput_msg_per_sec") + if min_tp is not None and throughput < min_tp: + self.trigger_alerts( + "warning", f"{name} throughput {throughput:.2f}/s below threshold" + ) + + if health.lost_messages: + self.metrics_collector.increment_message_loss( + name, health.lost_messages + ) + self.trigger_alerts( + "warning", f"{name} lost {health.lost_messages} messages" + ) + health.lost_messages = 0 + + # Detect stale connections + if age > self.check_interval * 2: + await 
self.handle_connection_failure(name, health) + + health.window_start = now + health.messages = 0 + health.bytes_received = 0 + + if self.queue_size_fn is not None: + depth = self.queue_size_fn() + self.metrics_collector.record_queue_depth(self.queue_name, depth) + max_depth = self.thresholds.get("max_queue_depth") + if max_depth is not None and depth > max_depth: + self.trigger_alerts( + "warning", f"queue depth {depth} exceeds {max_depth}" + ) + + async def handle_connection_failure( + self, name: str, health: ConnectionHealth + ) -> None: + health.status = "reconnecting" + health.reconnect_attempts += 1 + self.metrics_collector.increment_reconnect(name) + success = await self.recovery_manager.reconnect(name) + if success: + now = time.time() + health.status = "connected" + health.window_start = now + health.last_message_ts = now + health.messages = 0 + else: + health.status = "down" + self.trigger_alerts("error", f"{name} reconnection failed") + max_attempts = self.thresholds.get("max_reconnect_attempts") + if max_attempts is not None and health.reconnect_attempts > max_attempts: + self.trigger_alerts( + "error", + f"{name} exceeded max reconnect attempts: {health.reconnect_attempts}", + ) + + def collect_performance_metrics(self) -> Dict[str, float]: + metrics: Dict[str, float] = {} + try: # best effort; psutil may not be installed + import psutil + + proc = psutil.Process() + metrics["memory_bytes"] = float(proc.memory_info().rss) + metrics["cpu_percent"] = float(psutil.cpu_percent(interval=None)) + except Exception: # pragma: no cover - optional dependency + pass + return metrics + + def trigger_alerts(self, level: str, message: str) -> None: + self.alert_manager.send(level, message) + + def generate_health_report(self) -> Dict[str, Dict[str, float]]: + report: Dict[str, Dict[str, float]] = {} + for name, health in self.connection_status.items(): + report[name] = { + "status": health.status, + "latency_ms": health.latency_ms, + "messages": health.messages, + 
"reconnect_attempts": health.reconnect_attempts, + } + return report
diff --git a/quanttradeai/streaming/monitoring/metrics_collector.py b/quanttradeai/streaming/monitoring/metrics_collector.py
new file mode 100644
index 0000000..6d9dbca
--- /dev/null
+++ b/quanttradeai/streaming/monitoring/metrics_collector.py
@@ -0,0 +1,102 @@
+"""Prometheus-based metrics collection for streaming health."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from prometheus_client import Counter, Gauge, REGISTRY
+
+
+def _get_metric(metric_cls, name: str, documentation: str, labelnames: list[str]):
+    """Return the collector registered as ``name``, creating it if absent.
+
+    The default registry is consulted first so that executing this module a
+    second time reuses the existing collector instead of constructing a new
+    one under the same name.
+    """
+    try:
+        # NOTE: ``_names_to_collectors`` is a private prometheus_client attribute.
+        return REGISTRY._names_to_collectors[name]  # type: ignore[attr-defined]
+    except KeyError:
+        return metric_cls(name, documentation, labelnames)
+
+
+# Metric definitions -----------------------------------------------------------
+_throughput = _get_metric(
+    Gauge,
+    "stream_message_throughput_per_sec",
+    "Incoming message rate per connection",
+    ["connection"],
+)
+_latency = _get_metric(
+    Gauge,
+    "stream_connection_latency_ms",
+    "Measured ping-pong latency in ms",
+    ["connection"],
+)
+_freshness = _get_metric(
+    Gauge,
+    "stream_data_freshness_seconds",
+    "Seconds since the last message was received",
+    ["connection"],
+)
+_reconnects = _get_metric(
+    Counter,
+    "stream_reconnect_total",
+    "Total reconnection attempts",
+    ["connection"],
+)
+_loss = _get_metric(
+    Counter,
+    "stream_message_loss_total",
+    "Detected missing messages per connection",
+    ["connection"],
+)
+_queue_depth = _get_metric(
+    Gauge,
+    "stream_buffer_queue_depth",
+    "Queue size of streaming buffers",
+    ["buffer"],
+)
+_bandwidth = _get_metric(
+    Gauge,
+    "stream_connection_bandwidth_bytes_per_sec",
+    "Approximate inbound bandwidth per connection",
+    ["connection"],
+)
+
+
+@dataclass
+class MetricsCollector:
+    """Lightweight wrapper around Prometheus metrics.
+
+    The class holds no state of its own; each method updates one of the
+    module-level collectors for the given label value.
+    """
+
+    def record_throughput(self, connection: str, rate: float) -> None:
+        """Set the throughput gauge for ``connection`` to ``rate``."""
+        _throughput.labels(connection=connection).set(rate)
+
+    def record_latency(self, connection: str, latency_ms: float) -> None:
+        """Set the latency gauge for ``connection`` to ``latency_ms``."""
+        _latency.labels(connection=connection).set(latency_ms)
+
+    def record_data_freshness(self, connection: str, age: float) -> None:
+        """Set the data-freshness gauge for ``connection`` to ``age``."""
+        _freshness.labels(connection=connection).set(age)
+
+    def increment_reconnect(self, connection: str) -> None:
+        """Add one to the reconnect counter for ``connection``."""
+        _reconnects.labels(connection=connection).inc()
+
+    def increment_message_loss(self, connection: str, count: int) -> None:
+        """Add ``count`` to the message-loss counter for ``connection``."""
+        _loss.labels(connection=connection).inc(count)
+
+    def record_queue_depth(self, buffer: str, depth: int) -> None:
+        """Set the queue-depth gauge for ``buffer`` to ``depth``."""
+        _queue_depth.labels(buffer=buffer).set(depth)
+
+    def record_bandwidth(self, connection: str, rate: float) -> None:
+        """Set the bandwidth gauge for ``connection`` to ``rate``."""
+        _bandwidth.labels(connection=connection).set(rate)
diff --git a/quanttradeai/streaming/monitoring/recovery_manager.py b/quanttradeai/streaming/monitoring/recovery_manager.py
new file mode 100644
index 0000000..87c7862
--- /dev/null
+++ b/quanttradeai/streaming/monitoring/recovery_manager.py
@@ -0,0 +1,82 @@
+"""Automatic reconnection utilities for streaming."""
+
+from __future__ import annotations
+
+import asyncio
+import random
+from dataclasses import dataclass, field
+from typing import Awaitable, Callable, Optional
+from pybreaker import CircuitBreaker, CircuitBreakerError
+
+from quanttradeai.streaming.logging import logger
+
+
+@dataclass
+class RecoveryManager:
+    """Handle reconnection attempts with exponential backoff and circuit breaking.
+
+    Attributes:
+        max_attempts: Attempts made per ``reconnect`` call; also used as the
+            breaker's ``fail_max``.
+        base_delay: Base sleep in seconds after the first failed attempt; it
+            doubles after every further failure.
+        reset_timeout: Forwarded to the breaker as ``reset_timeout``.
+        circuit_breaker: Breaker instance created in ``__post_init__``.
+    """
+
+    max_attempts: int = 5
+    base_delay: float = 1.0
+    reset_timeout: float = 60.0
+    circuit_breaker: CircuitBreaker = field(init=False)
+
+    def __post_init__(self) -> None:
+        self.circuit_breaker = CircuitBreaker(
+            fail_max=self.max_attempts, reset_timeout=self.reset_timeout
+        )
+
+    async def reconnect(
+        self, name: str, connect: Optional[Callable[[], Awaitable[None]]] = None
+    ) -> bool:
+        """Attempt to reconnect using exponential backoff.
+
+        If reconnection keeps failing ``max_attempts`` times the internal circuit
+        breaker opens and further attempts immediately fail until the
+        ``reset_timeout`` elapses.
+
+        Args:
+            name: Connection label attached to every log record.
+            connect: Coroutine function that re-establishes the connection;
+                when omitted a no-op coroutine is used.
+
+        Returns:
+            ``True`` as soon as one attempt succeeds; ``False`` when the
+            circuit breaker is open or every attempt failed.
+        """
+
+        delay = self.base_delay
+        if connect is None:
+
+            async def connect():
+                return None
+
+        for attempt in range(self.max_attempts):
+            try:
+                # NOTE(review): pybreaker's ``call_async`` is its Tornado integration;
+                # confirm tornado is installed and its future is awaitable here.
+                await self.circuit_breaker.call_async(connect)
+                logger.info("reconnect_success", name=name, attempt=attempt)
+                return True
+            except CircuitBreakerError:
+                logger.error("circuit_open", name=name)
+                return False
+            except Exception as exc:  # pragma: no cover - best effort logging
+                logger.warning(
+                    "reconnect_failed", name=name, attempt=attempt, error=str(exc)
+                )
+                # Back off only when another attempt will follow.
+                if attempt < self.max_attempts - 1:
+                    jitter = random.random()
+                    await asyncio.sleep(delay + jitter)
+                    delay *= 2
+        logger.error("reconnect_exhausted", name=name)
+        return False
diff --git a/tests/streaming/test_health_monitor.py b/tests/streaming/test_health_monitor.py
new file mode 100644
index 0000000..5f9ff24
--- /dev/null
+++ b/tests/streaming/test_health_monitor.py
@@ -0,0 +1,125 @@
+import asyncio
+import time
+import importlib.util
+import pathlib
+import sys
+import types
+
+# Load the monitoring modules straight from their files under a synthetic
+# top-level "monitoring" package rather than via the installed package path.
+monitoring_dir = pathlib.Path(__file__).resolve().parents[2] / "quanttradeai" / "streaming" / "monitoring"
+pkg = types.ModuleType("monitoring")
+pkg.__path__ = [str(monitoring_dir)]
+sys.modules.setdefault("monitoring", pkg)
+
+def _load(name: str):
+    """Execute ``<name>.py`` from ``monitoring_dir`` as ``monitoring.<name>``."""
+    spec = importlib.util.spec_from_file_location(f"monitoring.{name}", monitoring_dir / f"{name}.py")
+    module = importlib.util.module_from_spec(spec)
+    module.__package__ = "monitoring"
+    sys.modules[f"monitoring.{name}"] = module
+    spec.loader.exec_module(module)
+    return module
+
+alerts = _load("alerts")
+metrics_collector = _load("metrics_collector")
+recovery_manager = _load("recovery_manager")
+health = _load("health_monitor")
+
+AlertManager = alerts.AlertManager
+MetricsCollector = metrics_collector.MetricsCollector
+RecoveryManager = recovery_manager.RecoveryManager
+ConnectionHealth = health.ConnectionHealth
+StreamingHealthMonitor = health.StreamingHealthMonitor
+
+
+class CollectingAlertManager(AlertManager):
+    """Alert manager whose extra callback appends ``(level, message)`` to ``records``."""
+    def __init__(self):
+        super().__init__(channels=["log"])
+        self.records = []
+        self.callbacks.append(lambda lvl, msg: self.records.append((lvl, msg)))
+
+
+class DummyRecovery(RecoveryManager):
+    """Recovery stub that only remembers whether ``reconnect`` was awaited."""
+    def __init__(self):
+        super().__init__(max_attempts=1)
+        self.called = False
+
+    async def reconnect(self, name: str, connect=None) -> bool:  # pragma: no cover - simple override
+        self.called = True
+        return True
+
+
+class CountingRecovery(RecoveryManager):
+    """Recovery stub that counts how many times ``reconnect`` was awaited."""
+    def __init__(self):
+        super().__init__(max_attempts=3)
+        self.calls = 0
+
+    async def reconnect(self, name: str, connect=None) -> bool:  # pragma: no cover - simple override
+        self.calls += 1
+        return True
+
+
+async def run_latency_check() -> CollectingAlertManager:
+    """Run one health check on a connection whose latency exceeds the threshold."""
+    alerts = CollectingAlertManager()
+    monitor = StreamingHealthMonitor(
+        connection_status={"c": ConnectionHealth(latency_ms=200)},
+        metrics_collector=MetricsCollector(),
+        alert_manager=alerts,
+        recovery_manager=DummyRecovery(),
+        check_interval=0.1,
+        thresholds={"max_latency_ms": 100, "min_throughput_msg_per_sec": 0},
+    )
+    await monitor._check_connections_once()
+    return alerts
+
+
+def test_latency_alert_triggered():
+    alerts = asyncio.run(run_latency_check())
+    assert any("latency" in msg for _, msg in alerts.records)
+
+
+async def run_recovery_check() -> DummyRecovery:
+    """Run one health check on a connection whose last message is 10 s old."""
+    alerts = CollectingAlertManager()
+    recovery = DummyRecovery()
+    stale = ConnectionHealth()
+    stale.last_message_ts = time.time() - 10
+    monitor = StreamingHealthMonitor(
+        connection_status={"c": stale},
+        metrics_collector=MetricsCollector(),
+        alert_manager=alerts,
+        recovery_manager=recovery,
+        check_interval=0.1,
+    )
+    await monitor._check_connections_once()
+    return recovery
+
+
+def test_stale_connection_triggers_recovery():
+    recovery = asyncio.run(run_recovery_check())
+    assert recovery.called
+
+
+def test_last_message_ts_reset_after_reconnect():
+    """A second check right after a recovery must not trigger another reconnect."""
+    alerts = CollectingAlertManager()
+    recovery = CountingRecovery()
+    stale = ConnectionHealth()
+    stale.last_message_ts = time.time() - 10
+    monitor = StreamingHealthMonitor(
+        connection_status={"c": stale},
+        metrics_collector=MetricsCollector(),
+        alert_manager=alerts,
+        recovery_manager=recovery,
+        check_interval=0.1,
+    )
+    asyncio.run(monitor._check_connections_once())
+    first = recovery.calls
+    asyncio.run(monitor._check_connections_once())
+    assert recovery.calls == first
+    assert stale.last_message_ts > time.time() - 1
\ No newline at end of file
diff --git a/tests/streaming/test_health_monitor_enhanced.py b/tests/streaming/test_health_monitor_enhanced.py
new file mode 100644
index 0000000..019d84f
--- /dev/null
+++ b/tests/streaming/test_health_monitor_enhanced.py
@@ -0,0 +1,66 @@
+import asyncio
+import time
+
+from fastapi.testclient import TestClient
+
+from quanttradeai.streaming.monitoring import (
+    AlertManager,
+    MetricsCollector,
+    RecoveryManager,
+    StreamingHealthMonitor,
+    create_health_app,
+)
+
+
+class CollectingAlerts(AlertManager):
+    """Alert manager whose extra callback appends ``(level, message)`` to ``records``."""
+    def __init__(self, **kwargs):
+        super().__init__(channels=["log"], **kwargs)
+        self.records = []
+        self.callbacks.append(lambda lvl, msg: self.records.append((lvl, msg)))
+
+
+def test_message_loss_and_queue_depth():
+    """A sequence gap and an over-threshold queue size both produce alerts."""
+    alerts = CollectingAlerts(escalation_threshold=3)
+    monitor = StreamingHealthMonitor(
+        metrics_collector=MetricsCollector(),
+        alert_manager=alerts,
+        recovery_manager=RecoveryManager(max_attempts=1),
+        thresholds={"max_queue_depth": 10},
+        queue_size_fn=lambda: 20,
+    )
+    monitor.record_message("c", sequence=1)
+    monitor.record_message("c", sequence=3)
+    asyncio.run(monitor._check_connections_once())
+    assert any("lost" in m for _, m in alerts.records)
+    assert any("queue depth" in m for _, m in alerts.records)
+
+
+def test_alert_escalation():
+    """With a threshold of 2, the second identical warning is recorded as an error."""
+    alerts = CollectingAlerts(escalation_threshold=2)
+    alerts.send("warning", "issue")
+    alerts.send("warning", "issue")
+    assert ("error", "issue") in alerts.records
+
+
+def test_health_api_endpoints():
+    """``/health``, ``/status`` and ``/metrics`` all answer with HTTP 200."""
+    monitor = StreamingHealthMonitor()
+    app = create_health_app(monitor)
+    client = TestClient(app)
+    assert client.get("/health").status_code == 200
+    assert client.get("/status").status_code == 200
+    assert client.get("/metrics").status_code == 200
+
+
+def test_recovery_manager_circuit_breaker():
+    """``reconnect`` reports failure on consecutive calls when connecting always raises."""
+    rec = RecoveryManager(max_attempts=1, reset_timeout=1)
+
+    async def bad_connect():
+        raise RuntimeError("fail")
+
+    assert not asyncio.run(rec.reconnect("c", bad_connect))
+    assert not asyncio.run(rec.reconnect("c", bad_connect))