Skip to content

Commit 91e6ab6

Browse files
committed
test: cover streaming safeguards
1 parent ec5ccf5 commit 91e6ab6

File tree

14 files changed

+355
-18
lines changed

14 files changed

+355
-18
lines changed

config/streaming.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ streaming:
44
websocket_url: "wss://stream.data.alpaca.markets/v2/iex"
55
auth_method: "api_key"
66
subscriptions: ["trades", "quotes"]
7+
rate_limit:
8+
default_rate: 100
9+
burst_allowance: 50
10+
circuit_breaker:
11+
failure_threshold: 5
12+
timeout: 30
713
buffer_size: 10000
814
reconnect_attempts: 5
915
health_check_interval: 30

poetry.lock

Lines changed: 40 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ dependencies = [
2626
"pydantic (>=2.11.7,<3.0.0)",
2727
"alpha-vantage (>=3.0.0,<4.0.0)",
2828
"litellm (>=1.75.0,<2.0.0)",
29-
"websockets (>=15.0.1,<16.0.0)"
29+
"websockets (>=15.0.1,<16.0.0)",
30+
"structlog (>=25.4.0,<26.0.0)",
31+
"pybreaker (>=1.4.0,<2.0.0)",
32+
"asyncio-throttle (>=1.0.2,<2.0.0)",
33+
"prometheus-client (>=0.22.1,<0.23.0)"
3034
]
3135

3236
[tool.poetry]

quanttradeai/streaming/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
"""Streaming infrastructure package."""
22

33
from .gateway import StreamingGateway
4+
from .auth_manager import AuthManager
5+
from .rate_limiter import AdaptiveRateLimiter
6+
from .connection_pool import ConnectionPool
47

5-
__all__ = ["StreamingGateway"]
8+
__all__ = [
9+
"StreamingGateway",
10+
"AuthManager",
11+
"AdaptiveRateLimiter",
12+
"ConnectionPool",
13+
]

quanttradeai/streaming/adapters/base_adapter.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
import websockets
1111

12+
from ..auth_manager import AuthManager
13+
from ..rate_limiter import AdaptiveRateLimiter
14+
1215

1316
@dataclass
1417
class DataProviderAdapter(ABC):
@@ -19,18 +22,28 @@ class DataProviderAdapter(ABC):
1922
connection: Optional[websockets.WebSocketClientProtocol] = field(
2023
default=None, init=False
2124
)
25+
circuit_breaker: Any = field(default=None, init=False)
26+
rate_limiter: Optional[AdaptiveRateLimiter] = field(default=None, init=False)
27+
auth_manager: Optional[AuthManager] = field(default=None, init=False)
2228

2329
async def connect(self) -> None:
2430
"""Establish a WebSocket connection to the provider."""
2531

26-
self.connection = await websockets.connect(self.websocket_url)
32+
headers = None
33+
if self.auth_manager:
34+
headers = await self.auth_manager.get_auth_headers()
35+
self.connection = await websockets.connect(
36+
self.websocket_url, extra_headers=headers
37+
)
2738

2839
async def subscribe(self, channel: str, symbols: List[str]) -> None:
2940
"""Send a subscription message for ``symbols`` on ``channel``."""
3041

3142
if self.connection is None:
3243
await self.connect()
3344
message = self._build_subscribe_message(channel, symbols)
45+
if self.rate_limiter:
46+
await self.rate_limiter.acquire()
3447
await self.connection.send(json.dumps(message))
3548

3649
async def listen(self) -> AsyncIterator[Dict[str, Any]]:
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Authentication and authorization utilities for data providers."""
2+
3+
from __future__ import annotations
4+
5+
import os
6+
from dataclasses import dataclass
7+
from datetime import datetime, timedelta
8+
from typing import Dict, Optional
9+
10+
11+
@dataclass
12+
class AuthManager:
13+
"""Simple authentication manager with token refresh support."""
14+
15+
provider: str
16+
_token: Optional[str] = None
17+
_expires_at: datetime = datetime.fromtimestamp(0)
18+
19+
def _load_credentials(self) -> Dict[str, str]:
20+
"""Load credentials from environment variables."""
21+
key = os.getenv(f"{self.provider.upper()}_API_KEY", "")
22+
secret = os.getenv(f"{self.provider.upper()}_API_SECRET", "")
23+
return {"key": key, "secret": secret}
24+
25+
def _token_needs_refresh(self) -> bool:
26+
return datetime.utcnow() >= self._expires_at - timedelta(minutes=5)
27+
28+
async def _refresh_token(self) -> None:
29+
creds = self._load_credentials()
30+
# In production, call provider-specific auth endpoint.
31+
self._token = creds.get("key")
32+
self._expires_at = datetime.utcnow() + timedelta(hours=1)
33+
34+
async def get_auth_headers(self) -> Dict[str, str]:
35+
if self._token is None or self._token_needs_refresh():
36+
await self._refresh_token()
37+
return {"Authorization": f"Bearer {self._token}"}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Connection pooling and health management."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
from dataclasses import dataclass, field
7+
from typing import Any, Set, Callable, Awaitable
8+
9+
10+
@dataclass
11+
class ConnectionPool:
12+
"""Maintain a pool of reusable WebSocket connections."""
13+
14+
max_connections: int = 10
15+
_pool: asyncio.Queue = field(init=False)
16+
_active: Set[Any] = field(init=False, default_factory=set)
17+
18+
def __post_init__(self) -> None:
19+
self._pool = asyncio.Queue(maxsize=self.max_connections)
20+
21+
async def acquire_connection(self, factory: Callable[[], Awaitable[Any]]) -> Any:
22+
if not self._pool.empty():
23+
conn = await self._pool.get()
24+
elif len(self._active) < self.max_connections:
25+
conn = await factory()
26+
else:
27+
conn = await self._pool.get()
28+
self._active.add(conn)
29+
return conn
30+
31+
async def release(self, conn: Any) -> None:
32+
if conn in self._active:
33+
self._active.remove(conn)
34+
await self._pool.put(conn)
35+
36+
async def _ping_all_connections(self) -> None:
37+
for conn in list(self._active):
38+
if hasattr(conn, "ping"):
39+
await conn.ping()
40+
41+
async def health_check_loop(self, interval: int = 30) -> None:
42+
while True:
43+
await self._ping_all_connections()
44+
await asyncio.sleep(interval)

quanttradeai/streaming/gateway.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .adapters.ib_adapter import IBAdapter
1515
from .processors import MessageProcessor
1616
from .monitoring.metrics import Metrics
17+
from .logging import logger
1718

1819
AdapterMap = {
1920
"alpaca": AlpacaAdapter,
@@ -51,7 +52,12 @@ def __post_init__(self) -> None:
5152
if adapter_cls is None:
5253
raise ValueError(f"Unknown provider: {name}")
5354
adapter = adapter_cls(websocket_url=url)
54-
self.websocket_manager.add_adapter(adapter)
55+
self.websocket_manager.add_adapter(
56+
adapter,
57+
circuit_breaker_cfg=provider_cfg.get("circuit_breaker", {}),
58+
rate_limit_cfg=provider_cfg.get("rate_limit"),
59+
auth_method=provider_cfg.get("auth_method", "api_key"),
60+
)
5561

5662
# Subscription API -------------------------------------------------
5763
def subscribe_to_trades(self, symbols: List[str], callback: Callback) -> None:
@@ -68,10 +74,14 @@ def subscribe_to_quotes(self, symbols: List[str], callback: Callback) -> None:
6874

6975
# Runtime ----------------------------------------------------------
7076
async def _dispatch(self, provider: str, message: Dict) -> None:
71-
self.metrics.increment()
77+
symbol = message.get("symbol", "")
78+
self.metrics.record_message(provider, symbol)
7279
processed = self.message_processor.process(message)
7380
await self.buffer.put(processed)
7481
msg_type = message.get("type", "trades")
82+
logger.debug(
83+
"dispatch_message", provider=provider, type=msg_type, symbol=symbol
84+
)
7585
for cb in self._callbacks.get(msg_type, []):
7686
res = cb(processed)
7787
if asyncio.iscoroutine(res):

quanttradeai/streaming/logging.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Structured logging utilities for streaming infrastructure."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
import structlog
7+
8+
9+
def configure_logging(level: str = "INFO") -> None:
10+
"""Configure structlog for JSON formatted logs.
11+
12+
Args:
13+
level: Minimum log level.
14+
"""
15+
structlog.configure(
16+
processors=[
17+
structlog.processors.TimeStamper(fmt="iso"),
18+
structlog.processors.add_log_level,
19+
structlog.contextvars.merge_contextvars,
20+
structlog.processors.JSONRenderer(),
21+
],
22+
wrapper_class=structlog.make_filtering_bound_logger(getattr(logging, level)),
23+
)
24+
25+
26+
configure_logging()
27+
logger = structlog.get_logger()
Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,30 @@
1-
"""Streaming metrics collection."""
1+
"""Prometheus metrics for streaming infrastructure."""
22

33
from __future__ import annotations
44

55
from dataclasses import dataclass
66

7+
from prometheus_client import Counter, Gauge, Histogram
8+
9+
10+
messages_processed = Counter(
11+
"websocket_messages_total", "Total processed messages", ["provider", "symbol"]
12+
)
13+
connection_latency = Histogram(
14+
"websocket_connection_latency_seconds", "Connection establishment time"
15+
)
16+
active_connections = Gauge("websocket_active_connections", "Current active connections")
17+
718

819
@dataclass
920
class Metrics:
10-
"""Simple in-memory metrics tracker."""
21+
"""Convenience wrapper around Prometheus metrics."""
1122

12-
messages_received: int = 0
23+
def record_message(self, provider: str, symbol: str) -> None:
24+
messages_processed.labels(provider=provider, symbol=symbol).inc()
1325

14-
def increment(self) -> None:
15-
"""Increment the received message counter."""
26+
def record_connection_latency(self, seconds: float) -> None:
27+
connection_latency.observe(seconds)
1628

17-
self.messages_received += 1
29+
def set_active_connections(self, count: int) -> None:
30+
active_connections.set(count)

0 commit comments

Comments
 (0)