|
| 1 | +"""High-level streaming gateway orchestrator.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import asyncio |
| 6 | +from dataclasses import dataclass, field |
| 7 | +from typing import Callable, Dict, List, Tuple |
| 8 | + |
| 9 | +import yaml |
| 10 | + |
| 11 | +from .stream_buffer import StreamBuffer |
| 12 | +from .websocket_manager import WebSocketManager |
| 13 | +from .adapters.alpaca_adapter import AlpacaAdapter |
| 14 | +from .adapters.ib_adapter import IBAdapter |
| 15 | +from .processors import MessageProcessor |
| 16 | +from .monitoring.metrics import Metrics |
| 17 | + |
| 18 | +AdapterMap = { |
| 19 | + "alpaca": AlpacaAdapter, |
| 20 | + "interactive_brokers": IBAdapter, |
| 21 | +} |
| 22 | + |
| 23 | +Callback = Callable[[str, Dict], None] |
| 24 | + |
| 25 | + |
| 26 | +@dataclass |
| 27 | +class StreamingGateway: |
| 28 | + """Main entry point for real-time data streaming.""" |
| 29 | + |
| 30 | + config_path: str |
| 31 | + websocket_manager: WebSocketManager = field(init=False) |
| 32 | + message_processor: MessageProcessor = field(init=False) |
| 33 | + buffer: StreamBuffer = field(init=False) |
| 34 | + metrics: Metrics = field(init=False) |
| 35 | + _subscriptions: List[Tuple[str, List[str]]] = field(default_factory=list) |
| 36 | + _callbacks: Dict[str, List[Callback]] = field(default_factory=dict) |
| 37 | + |
| 38 | + def __post_init__(self) -> None: |
| 39 | + with open(self.config_path, "r") as f: |
| 40 | + cfg = yaml.safe_load(f)["streaming"] |
| 41 | + self.websocket_manager = WebSocketManager( |
| 42 | + reconnect_attempts=cfg.get("reconnect_attempts", 5) |
| 43 | + ) |
| 44 | + self.message_processor = MessageProcessor() |
| 45 | + self.buffer = StreamBuffer(cfg.get("buffer_size", 1000)) |
| 46 | + self.metrics = Metrics() |
| 47 | + for provider_cfg in cfg.get("providers", []): |
| 48 | + name = provider_cfg["name"] |
| 49 | + url = provider_cfg["websocket_url"] |
| 50 | + adapter_cls = AdapterMap.get(name) |
| 51 | + if adapter_cls is None: |
| 52 | + raise ValueError(f"Unknown provider: {name}") |
| 53 | + adapter = adapter_cls(websocket_url=url) |
| 54 | + self.websocket_manager.add_adapter(adapter) |
| 55 | + |
| 56 | + # Subscription API ------------------------------------------------- |
| 57 | + def subscribe_to_trades(self, symbols: List[str], callback: Callback) -> None: |
| 58 | + """Subscribe to trade updates for ``symbols``.""" |
| 59 | + |
| 60 | + self._subscriptions.append(("trades", symbols)) |
| 61 | + self._callbacks.setdefault("trades", []).append(callback) |
| 62 | + |
| 63 | + def subscribe_to_quotes(self, symbols: List[str], callback: Callback) -> None: |
| 64 | + """Subscribe to quote updates for ``symbols``.""" |
| 65 | + |
| 66 | + self._subscriptions.append(("quotes", symbols)) |
| 67 | + self._callbacks.setdefault("quotes", []).append(callback) |
| 68 | + |
| 69 | + # Runtime ---------------------------------------------------------- |
| 70 | + async def _dispatch(self, provider: str, message: Dict) -> None: |
| 71 | + self.metrics.increment() |
| 72 | + processed = self.message_processor.process(message) |
| 73 | + await self.buffer.put(processed) |
| 74 | + msg_type = message.get("type", "trades") |
| 75 | + for cb in self._callbacks.get(msg_type, []): |
| 76 | + res = cb(processed) |
| 77 | + if asyncio.iscoroutine(res): |
| 78 | + await res |
| 79 | + |
| 80 | + async def _start(self) -> None: |
| 81 | + await self.websocket_manager.connect_all() |
| 82 | + for channel, symbols in self._subscriptions: |
| 83 | + for adapter in self.websocket_manager.adapters: |
| 84 | + await adapter.subscribe(channel, symbols) |
| 85 | + await self.websocket_manager.run(self._dispatch) |
| 86 | + |
| 87 | + def start_streaming(self) -> None: |
| 88 | + """Blocking call that starts the streaming event loop.""" |
| 89 | + |
| 90 | + asyncio.run(self._start()) |
0 commit comments