Skip to content

Commit c8fd825

Browse files
authored
Merge pull request #35 from AKKI0511/implement-real-time-websocket-data-streaming
Add streaming gateway and WebSocket infrastructure
2 parents 5d4540d + 91e6ab6 commit c8fd825

23 files changed

+762
-2
lines changed

config/streaming.yaml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
streaming:
2+
providers:
3+
- name: "alpaca"
4+
websocket_url: "wss://stream.data.alpaca.markets/v2/iex"
5+
auth_method: "api_key"
6+
subscriptions: ["trades", "quotes"]
7+
rate_limit:
8+
default_rate: 100
9+
burst_allowance: 50
10+
circuit_breaker:
11+
failure_threshold: 5
12+
timeout: 30
13+
buffer_size: 10000
14+
reconnect_attempts: 5
15+
health_check_interval: 30

poetry.lock

Lines changed: 40 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@ dependencies = [
2525
"joblib (>=1.5.1,<2.0.0)",
2626
"pydantic (>=2.11.7,<3.0.0)",
2727
"alpha-vantage (>=3.0.0,<4.0.0)",
28-
"litellm (>=1.75.0,<2.0.0)"
28+
"litellm (>=1.75.0,<2.0.0)",
29+
"websockets (>=15.0.1,<16.0.0)",
30+
"structlog (>=25.4.0,<26.0.0)",
31+
"pybreaker (>=1.4.0,<2.0.0)",
32+
"asyncio-throttle (>=1.0.2,<2.0.0)",
33+
"prometheus-client (>=0.22.1,<0.23.0)"
2934
]
3035

3136
[tool.poetry]

quanttradeai/streaming/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Streaming infrastructure package."""
2+
3+
from .gateway import StreamingGateway
4+
from .auth_manager import AuthManager
5+
from .rate_limiter import AdaptiveRateLimiter
6+
from .connection_pool import ConnectionPool
7+
8+
__all__ = [
9+
"StreamingGateway",
10+
"AuthManager",
11+
"AdaptiveRateLimiter",
12+
"ConnectionPool",
13+
]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Provider adapter implementations."""
2+
3+
from .base_adapter import DataProviderAdapter
4+
from .alpaca_adapter import AlpacaAdapter
5+
from .ib_adapter import IBAdapter
6+
7+
__all__ = ["DataProviderAdapter", "AlpacaAdapter", "IBAdapter"]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Alpaca data provider adapter."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import Any, Dict, List
7+
8+
from .base_adapter import DataProviderAdapter
9+
10+
11+
@dataclass
12+
class AlpacaAdapter(DataProviderAdapter):
13+
"""Adapter for Alpaca's streaming API."""
14+
15+
name: str = "alpaca"
16+
17+
def _build_subscribe_message(
18+
self, channel: str, symbols: List[str]
19+
) -> Dict[str, Any]:
20+
return {"action": "subscribe", channel: symbols}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Base classes for data provider adapters."""
2+
3+
from __future__ import annotations
4+
5+
import json
6+
from abc import ABC, abstractmethod
7+
from dataclasses import dataclass, field
8+
from typing import Any, AsyncIterator, Dict, List, Optional
9+
10+
import websockets
11+
12+
from ..auth_manager import AuthManager
13+
from ..rate_limiter import AdaptiveRateLimiter
14+
15+
16+
@dataclass
17+
class DataProviderAdapter(ABC):
18+
"""Abstract adapter defining basic WebSocket operations."""
19+
20+
websocket_url: str
21+
name: str
22+
connection: Optional[websockets.WebSocketClientProtocol] = field(
23+
default=None, init=False
24+
)
25+
circuit_breaker: Any = field(default=None, init=False)
26+
rate_limiter: Optional[AdaptiveRateLimiter] = field(default=None, init=False)
27+
auth_manager: Optional[AuthManager] = field(default=None, init=False)
28+
29+
async def connect(self) -> None:
30+
"""Establish a WebSocket connection to the provider."""
31+
32+
headers = None
33+
if self.auth_manager:
34+
headers = await self.auth_manager.get_auth_headers()
35+
self.connection = await websockets.connect(
36+
self.websocket_url, extra_headers=headers
37+
)
38+
39+
async def subscribe(self, channel: str, symbols: List[str]) -> None:
40+
"""Send a subscription message for ``symbols`` on ``channel``."""
41+
42+
if self.connection is None:
43+
await self.connect()
44+
message = self._build_subscribe_message(channel, symbols)
45+
if self.rate_limiter:
46+
await self.rate_limiter.acquire()
47+
await self.connection.send(json.dumps(message))
48+
49+
async def listen(self) -> AsyncIterator[Dict[str, Any]]:
50+
"""Yield parsed JSON messages from the connection."""
51+
52+
if self.connection is None:
53+
raise RuntimeError("Connection not established")
54+
async for message in self.connection:
55+
yield json.loads(message)
56+
57+
@abstractmethod
58+
def _build_subscribe_message(
59+
self, channel: str, symbols: List[str]
60+
) -> Dict[str, Any]:
61+
"""Return provider specific subscription payload."""
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Interactive Brokers data provider adapter."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import Any, Dict, List
7+
8+
from .base_adapter import DataProviderAdapter
9+
10+
11+
@dataclass
12+
class IBAdapter(DataProviderAdapter):
13+
"""Adapter for Interactive Brokers TWS/Gateway streaming API."""
14+
15+
name: str = "interactive_brokers"
16+
17+
def _build_subscribe_message(
18+
self, channel: str, symbols: List[str]
19+
) -> Dict[str, Any]:
20+
return {"action": "subscribe", "channel": channel, "symbols": symbols}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Authentication and authorization utilities for data providers."""
2+
3+
from __future__ import annotations
4+
5+
import os
6+
from dataclasses import dataclass
7+
from datetime import datetime, timedelta
8+
from typing import Dict, Optional
9+
10+
11+
@dataclass
12+
class AuthManager:
13+
"""Simple authentication manager with token refresh support."""
14+
15+
provider: str
16+
_token: Optional[str] = None
17+
_expires_at: datetime = datetime.fromtimestamp(0)
18+
19+
def _load_credentials(self) -> Dict[str, str]:
20+
"""Load credentials from environment variables."""
21+
key = os.getenv(f"{self.provider.upper()}_API_KEY", "")
22+
secret = os.getenv(f"{self.provider.upper()}_API_SECRET", "")
23+
return {"key": key, "secret": secret}
24+
25+
def _token_needs_refresh(self) -> bool:
26+
return datetime.utcnow() >= self._expires_at - timedelta(minutes=5)
27+
28+
async def _refresh_token(self) -> None:
29+
creds = self._load_credentials()
30+
# In production, call provider-specific auth endpoint.
31+
self._token = creds.get("key")
32+
self._expires_at = datetime.utcnow() + timedelta(hours=1)
33+
34+
async def get_auth_headers(self) -> Dict[str, str]:
35+
if self._token is None or self._token_needs_refresh():
36+
await self._refresh_token()
37+
return {"Authorization": f"Bearer {self._token}"}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Connection pooling and health management."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
from dataclasses import dataclass, field
7+
from typing import Any, Set, Callable, Awaitable
8+
9+
10+
@dataclass
11+
class ConnectionPool:
12+
"""Maintain a pool of reusable WebSocket connections."""
13+
14+
max_connections: int = 10
15+
_pool: asyncio.Queue = field(init=False)
16+
_active: Set[Any] = field(init=False, default_factory=set)
17+
18+
def __post_init__(self) -> None:
19+
self._pool = asyncio.Queue(maxsize=self.max_connections)
20+
21+
async def acquire_connection(self, factory: Callable[[], Awaitable[Any]]) -> Any:
22+
if not self._pool.empty():
23+
conn = await self._pool.get()
24+
elif len(self._active) < self.max_connections:
25+
conn = await factory()
26+
else:
27+
conn = await self._pool.get()
28+
self._active.add(conn)
29+
return conn
30+
31+
async def release(self, conn: Any) -> None:
32+
if conn in self._active:
33+
self._active.remove(conn)
34+
await self._pool.put(conn)
35+
36+
async def _ping_all_connections(self) -> None:
37+
for conn in list(self._active):
38+
if hasattr(conn, "ping"):
39+
await conn.ping()
40+
41+
async def health_check_loop(self, interval: int = 30) -> None:
42+
while True:
43+
await self._ping_all_connections()
44+
await asyncio.sleep(interval)

0 commit comments

Comments
 (0)