Skip to content

Commit ec5ccf5

Browse files
committed
Add streaming gateway and WebSocket infrastructure
1 parent 5d4540d commit ec5ccf5

File tree

18 files changed

+425
-2
lines changed

18 files changed

+425
-2
lines changed

config/streaming.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
streaming:
2+
providers:
3+
- name: "alpaca"
4+
websocket_url: "wss://stream.data.alpaca.markets/v2/iex"
5+
auth_method: "api_key"
6+
subscriptions: ["trades", "quotes"]
7+
buffer_size: 10000
8+
reconnect_attempts: 5
9+
health_check_interval: 30

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ dependencies = [
2525
"joblib (>=1.5.1,<2.0.0)",
2626
"pydantic (>=2.11.7,<3.0.0)",
2727
"alpha-vantage (>=3.0.0,<4.0.0)",
28-
"litellm (>=1.75.0,<2.0.0)"
28+
"litellm (>=1.75.0,<2.0.0)",
29+
"websockets (>=15.0.1,<16.0.0)"
2930
]
3031

3132
[tool.poetry]

quanttradeai/streaming/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Streaming infrastructure package."""
2+
3+
from .gateway import StreamingGateway
4+
5+
__all__ = ["StreamingGateway"]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Provider adapter implementations."""
2+
3+
from .base_adapter import DataProviderAdapter
4+
from .alpaca_adapter import AlpacaAdapter
5+
from .ib_adapter import IBAdapter
6+
7+
__all__ = ["DataProviderAdapter", "AlpacaAdapter", "IBAdapter"]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Alpaca data provider adapter."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import Any, Dict, List
7+
8+
from .base_adapter import DataProviderAdapter
9+
10+
11+
@dataclass
12+
class AlpacaAdapter(DataProviderAdapter):
13+
"""Adapter for Alpaca's streaming API."""
14+
15+
name: str = "alpaca"
16+
17+
def _build_subscribe_message(
18+
self, channel: str, symbols: List[str]
19+
) -> Dict[str, Any]:
20+
return {"action": "subscribe", channel: symbols}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Base classes for data provider adapters."""
2+
3+
from __future__ import annotations
4+
5+
import json
6+
from abc import ABC, abstractmethod
7+
from dataclasses import dataclass, field
8+
from typing import Any, AsyncIterator, Dict, List, Optional
9+
10+
import websockets
11+
12+
13+
@dataclass
14+
class DataProviderAdapter(ABC):
15+
"""Abstract adapter defining basic WebSocket operations."""
16+
17+
websocket_url: str
18+
name: str
19+
connection: Optional[websockets.WebSocketClientProtocol] = field(
20+
default=None, init=False
21+
)
22+
23+
async def connect(self) -> None:
24+
"""Establish a WebSocket connection to the provider."""
25+
26+
self.connection = await websockets.connect(self.websocket_url)
27+
28+
async def subscribe(self, channel: str, symbols: List[str]) -> None:
29+
"""Send a subscription message for ``symbols`` on ``channel``."""
30+
31+
if self.connection is None:
32+
await self.connect()
33+
message = self._build_subscribe_message(channel, symbols)
34+
await self.connection.send(json.dumps(message))
35+
36+
async def listen(self) -> AsyncIterator[Dict[str, Any]]:
37+
"""Yield parsed JSON messages from the connection."""
38+
39+
if self.connection is None:
40+
raise RuntimeError("Connection not established")
41+
async for message in self.connection:
42+
yield json.loads(message)
43+
44+
@abstractmethod
45+
def _build_subscribe_message(
46+
self, channel: str, symbols: List[str]
47+
) -> Dict[str, Any]:
48+
"""Return provider specific subscription payload."""
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Interactive Brokers data provider adapter."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import Any, Dict, List
7+
8+
from .base_adapter import DataProviderAdapter
9+
10+
11+
@dataclass
12+
class IBAdapter(DataProviderAdapter):
13+
"""Adapter for Interactive Brokers TWS/Gateway streaming API."""
14+
15+
name: str = "interactive_brokers"
16+
17+
def _build_subscribe_message(
18+
self, channel: str, symbols: List[str]
19+
) -> Dict[str, Any]:
20+
return {"action": "subscribe", "channel": channel, "symbols": symbols}

quanttradeai/streaming/gateway.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""High-level streaming gateway orchestrator."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
from dataclasses import dataclass, field
7+
from typing import Callable, Dict, List, Tuple
8+
9+
import yaml
10+
11+
from .stream_buffer import StreamBuffer
12+
from .websocket_manager import WebSocketManager
13+
from .adapters.alpaca_adapter import AlpacaAdapter
14+
from .adapters.ib_adapter import IBAdapter
15+
from .processors import MessageProcessor
16+
from .monitoring.metrics import Metrics
17+
18+
AdapterMap = {
19+
"alpaca": AlpacaAdapter,
20+
"interactive_brokers": IBAdapter,
21+
}
22+
23+
Callback = Callable[[str, Dict], None]
24+
25+
26+
@dataclass
27+
class StreamingGateway:
28+
"""Main entry point for real-time data streaming."""
29+
30+
config_path: str
31+
websocket_manager: WebSocketManager = field(init=False)
32+
message_processor: MessageProcessor = field(init=False)
33+
buffer: StreamBuffer = field(init=False)
34+
metrics: Metrics = field(init=False)
35+
_subscriptions: List[Tuple[str, List[str]]] = field(default_factory=list)
36+
_callbacks: Dict[str, List[Callback]] = field(default_factory=dict)
37+
38+
def __post_init__(self) -> None:
39+
with open(self.config_path, "r") as f:
40+
cfg = yaml.safe_load(f)["streaming"]
41+
self.websocket_manager = WebSocketManager(
42+
reconnect_attempts=cfg.get("reconnect_attempts", 5)
43+
)
44+
self.message_processor = MessageProcessor()
45+
self.buffer = StreamBuffer(cfg.get("buffer_size", 1000))
46+
self.metrics = Metrics()
47+
for provider_cfg in cfg.get("providers", []):
48+
name = provider_cfg["name"]
49+
url = provider_cfg["websocket_url"]
50+
adapter_cls = AdapterMap.get(name)
51+
if adapter_cls is None:
52+
raise ValueError(f"Unknown provider: {name}")
53+
adapter = adapter_cls(websocket_url=url)
54+
self.websocket_manager.add_adapter(adapter)
55+
56+
# Subscription API -------------------------------------------------
57+
def subscribe_to_trades(self, symbols: List[str], callback: Callback) -> None:
58+
"""Subscribe to trade updates for ``symbols``."""
59+
60+
self._subscriptions.append(("trades", symbols))
61+
self._callbacks.setdefault("trades", []).append(callback)
62+
63+
def subscribe_to_quotes(self, symbols: List[str], callback: Callback) -> None:
64+
"""Subscribe to quote updates for ``symbols``."""
65+
66+
self._subscriptions.append(("quotes", symbols))
67+
self._callbacks.setdefault("quotes", []).append(callback)
68+
69+
# Runtime ----------------------------------------------------------
70+
async def _dispatch(self, provider: str, message: Dict) -> None:
71+
self.metrics.increment()
72+
processed = self.message_processor.process(message)
73+
await self.buffer.put(processed)
74+
msg_type = message.get("type", "trades")
75+
for cb in self._callbacks.get(msg_type, []):
76+
res = cb(processed)
77+
if asyncio.iscoroutine(res):
78+
await res
79+
80+
async def _start(self) -> None:
81+
await self.websocket_manager.connect_all()
82+
for channel, symbols in self._subscriptions:
83+
for adapter in self.websocket_manager.adapters:
84+
await adapter.subscribe(channel, symbols)
85+
await self.websocket_manager.run(self._dispatch)
86+
87+
def start_streaming(self) -> None:
88+
"""Blocking call that starts the streaming event loop."""
89+
90+
asyncio.run(self._start())
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Monitoring utilities for streaming system."""
2+
3+
from .health_monitor import HealthMonitor
4+
from .metrics import Metrics
5+
6+
__all__ = ["HealthMonitor", "Metrics"]

0 commit comments

Comments
 (0)