Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions config/streaming.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,15 @@ streaming:
buffer_size: 10000
reconnect_attempts: 5
health_check_interval: 30
# Health-monitoring section consumed by StreamingGateway.__post_init__
# (keys are read via cfg_all.get("streaming_health", {})).
streaming_health:
  monitoring:
    enabled: true
    # Passed to StreamingHealthMonitor as check_interval — presumably
    # seconds between health checks; confirm against the monitor loop.
    check_interval: 5
    # How long metric history is kept — 3600 suggests seconds; verify.
    metrics_retention: 3600
  # Thresholds forwarded to StreamingHealthMonitor.thresholds.
  thresholds:
    max_latency_ms: 100
    min_throughput_msg_per_sec: 50
    # Also used as RecoveryManager(max_attempts=...).
    max_reconnect_attempts: 5
  alerts:
    enabled: true
    # Channels handed to AlertManager; "log" and "metrics" are supported.
    channels: ["log", "metrics"]
20 changes: 11 additions & 9 deletions quanttradeai/streaming/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Streaming infrastructure package."""

from .gateway import StreamingGateway
from .auth_manager import AuthManager
from .rate_limiter import AdaptiveRateLimiter
from .connection_pool import ConnectionPool
from .gateway import StreamingGateway
from .auth_manager import AuthManager
from .rate_limiter import AdaptiveRateLimiter
from .connection_pool import ConnectionPool
from .monitoring import StreamingHealthMonitor

__all__ = [
"StreamingGateway",
"AuthManager",
"AdaptiveRateLimiter",
"ConnectionPool",
]
"StreamingGateway",
"AuthManager",
"AdaptiveRateLimiter",
"ConnectionPool",
"StreamingHealthMonitor",
]
162 changes: 96 additions & 66 deletions quanttradeai/streaming/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,52 @@
>>> gw.subscribe_to_quotes(["MSFT", "TSLA"], callback=lambda m: print(m))
>>> # gw.start_streaming()
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field

from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple

import yaml

from .stream_buffer import StreamBuffer
from .websocket_manager import WebSocketManager
from .adapters.alpaca_adapter import AlpacaAdapter
from .adapters.ib_adapter import IBAdapter
from .processors import MessageProcessor
from .monitoring.metrics import Metrics
from .logging import logger

AdapterMap = {
"alpaca": AlpacaAdapter,
"interactive_brokers": IBAdapter,
}

Callback = Callable[[str, Dict], None]


@dataclass
class StreamingGateway:
"""Main entry point for real-time data streaming."""

config_path: str
websocket_manager: WebSocketManager = field(init=False)
message_processor: MessageProcessor = field(init=False)
buffer: StreamBuffer = field(init=False)
metrics: Metrics = field(init=False)

import yaml

from .stream_buffer import StreamBuffer
from .websocket_manager import WebSocketManager
from .adapters.alpaca_adapter import AlpacaAdapter
from .adapters.ib_adapter import IBAdapter
from .processors import MessageProcessor
from .monitoring import (
AlertManager,
MetricsCollector,
RecoveryManager,
StreamingHealthMonitor,
)
from .monitoring.metrics import Metrics
from .logging import logger

AdapterMap = {
"alpaca": AlpacaAdapter,
"interactive_brokers": IBAdapter,
}

Callback = Callable[[str, Dict], None]


@dataclass
class StreamingGateway:
"""Main entry point for real-time data streaming."""

config_path: str
websocket_manager: WebSocketManager = field(init=False)
message_processor: MessageProcessor = field(init=False)
buffer: StreamBuffer = field(init=False)
metrics: Metrics = field(init=False)
health_monitor: StreamingHealthMonitor = field(init=False)
_subscriptions: List[Tuple[str, List[str]]] = field(default_factory=list)
_callbacks: Dict[str, List[Callback]] = field(default_factory=dict)
_config_symbols: List[str] = field(default_factory=list)
def __post_init__(self) -> None:

def __post_init__(self) -> None:
with open(self.config_path, "r") as f:
cfg_all = yaml.safe_load(f)
cfg = cfg_all.get("streaming", {})
Expand All @@ -64,6 +71,25 @@ def __post_init__(self) -> None:
self.message_processor = MessageProcessor()
self.buffer = StreamBuffer(cfg.get("buffer_size", 1000))
self.metrics = Metrics()
health_cfg = cfg_all.get("streaming_health", {})
mon_cfg = health_cfg.get("monitoring", {})
thresh_cfg = health_cfg.get("thresholds", {})
alert_cfg = health_cfg.get("alerts", {})
self.health_monitor = StreamingHealthMonitor(
metrics_collector=MetricsCollector(),
alert_manager=AlertManager(alert_cfg.get("channels", ["log"])),
recovery_manager=RecoveryManager(
max_attempts=thresh_cfg.get("max_reconnect_attempts", 5)
),
check_interval=mon_cfg.get("check_interval", 5),
thresholds={
"max_latency_ms": thresh_cfg.get("max_latency_ms", 100),
"min_throughput_msg_per_sec": thresh_cfg.get(
"min_throughput_msg_per_sec", 0
),
"max_reconnect_attempts": thresh_cfg.get("max_reconnect_attempts", 5),
},
)
# Top-level symbol list for convenience
self._config_symbols = cfg.get("symbols", []) or []
for provider_cfg in cfg.get("providers", []):
Expand All @@ -85,41 +111,45 @@ def __post_init__(self) -> None:
for channel in subs:
# Store once; adapters subscribe per provider in _start
self._subscriptions.append((channel, prov_symbols))

# Subscription API -------------------------------------------------
def subscribe_to_trades(self, symbols: List[str], callback: Callback) -> None:
"""Subscribe to trade updates for ``symbols``."""

self._subscriptions.append(("trades", symbols))
self._callbacks.setdefault("trades", []).append(callback)

def subscribe_to_quotes(self, symbols: List[str], callback: Callback) -> None:
"""Subscribe to quote updates for ``symbols``."""

self._subscriptions.append(("quotes", symbols))
self._callbacks.setdefault("quotes", []).append(callback)

# Runtime ----------------------------------------------------------
async def _dispatch(self, provider: str, message: Dict) -> None:
symbol = message.get("symbol", "")
self.metrics.record_message(provider, symbol)
processed = self.message_processor.process(message)
await self.buffer.put(processed)
msg_type = message.get("type", "trades")
logger.debug(
"dispatch_message", provider=provider, type=msg_type, symbol=symbol
)

# Subscription API -------------------------------------------------
def subscribe_to_trades(self, symbols: List[str], callback: Callback) -> None:
"""Subscribe to trade updates for ``symbols``."""

self._subscriptions.append(("trades", symbols))
self._callbacks.setdefault("trades", []).append(callback)

def subscribe_to_quotes(self, symbols: List[str], callback: Callback) -> None:
"""Subscribe to quote updates for ``symbols``."""

self._subscriptions.append(("quotes", symbols))
self._callbacks.setdefault("quotes", []).append(callback)

# Runtime ----------------------------------------------------------
async def _dispatch(self, provider: str, message: Dict) -> None:
symbol = message.get("symbol", "")
self.metrics.record_message(provider, symbol)
self.health_monitor.record_message(provider)
processed = self.message_processor.process(message)
await self.buffer.put(processed)
msg_type = message.get("type", "trades")
logger.debug(
"dispatch_message", provider=provider, type=msg_type, symbol=symbol
)
for cb in self._callbacks.get(msg_type, []):
try:
res = cb(processed)
if asyncio.iscoroutine(res):
await res
except Exception as exc: # keep streaming robust to callback errors
logger.error("callback_error", error=str(exc))

async def _start(self) -> None:
await self.websocket_manager.connect_all()
# Optional health monitoring loop in the background
for adapter in self.websocket_manager.adapters:
self.health_monitor.register_connection(adapter.name)
asyncio.create_task(self.health_monitor.monitor_connection_health())
# Optional legacy health check for connection pool
try:
with open(self.config_path, "r") as f:
cfg = yaml.safe_load(f).get("streaming", {})
Expand All @@ -134,8 +164,8 @@ async def _start(self) -> None:
for adapter in self.websocket_manager.adapters:
await adapter.subscribe(channel, symbols)
await self.websocket_manager.run(self._dispatch)
def start_streaming(self) -> None:
"""Blocking call that starts the streaming event loop."""
asyncio.run(self._start())

def start_streaming(self) -> None:
"""Blocking call that starts the streaming event loop."""

asyncio.run(self._start())
22 changes: 16 additions & 6 deletions quanttradeai/streaming/monitoring/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
"""Monitoring utilities for streaming system."""

from .health_monitor import HealthMonitor
from .metrics import Metrics

__all__ = ["HealthMonitor", "Metrics"]
"""Monitoring utilities for streaming system."""

from .alerts import AlertManager
from .health_monitor import ConnectionHealth, StreamingHealthMonitor
from .metrics import Metrics
from .metrics_collector import MetricsCollector
from .recovery_manager import RecoveryManager

__all__ = [
"AlertManager",
"ConnectionHealth",
"StreamingHealthMonitor",
"MetricsCollector",
"RecoveryManager",
"Metrics",
]
43 changes: 43 additions & 0 deletions quanttradeai/streaming/monitoring/alerts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

"""Alert management utilities for streaming health monitoring."""

from dataclasses import dataclass, field
from typing import Callable, Iterable, List
import logging

logger = logging.getLogger(__name__)


@dataclass
class AlertManager:
"""Dispatch alerts through configurable channels.

Parameters
----------
channels:
Iterable of channels to emit alerts on. Supported values are
``"log"`` and ``"metrics"``. Additional channels can be plugged in by
registering callback functions via :attr:`callbacks`.
"""

channels: Iterable[str] = ("log",)
callbacks: List[Callable[[str, str], None]] = field(default_factory=list)

def send(self, level: str, message: str) -> None:
"""Send an alert with ``level`` and ``message``.

The method logs the alert and notifies any registered callbacks. The
``metrics`` channel is a placeholder for integration with external
monitoring systems and currently acts as a no-op.
"""

if "log" in self.channels:
log_fn = getattr(logger, level, logger.warning)
log_fn(message)
if "metrics" in self.channels:
# Integration point for metric-based alerting (e.g. Prometheus
# counters). Left as a no-op for lightweight deployments.
pass
for cb in self.callbacks:
cb(level, message)
Loading