1- import asyncio
2- import json
3- import tempfile
4- import asyncio
5- import json
6- import tempfile
7- from unittest .mock import patch
8-
9- import yaml
10-
11- from quanttradeai .streaming import StreamingGateway
1+ import asyncio
2+ import json
3+ import tempfile
4+ import time
5+ from typing import Awaitable , Callable , Dict , List , Optional
6+ from unittest .mock import AsyncMock , patch
7+
8+ import yaml
9+
10+ from quanttradeai .streaming .monitoring import StreamingHealthMonitor
11+
12+ import pytest
13+ from quanttradeai .streaming import StreamingGateway
14+
15+
16+ class StubProviderMonitor :
17+ def __init__ (
18+ self ,
19+ * ,
20+ streaming_monitor : Optional [StreamingHealthMonitor ] = None ,
21+ ** _ : Dict ,
22+ ) -> None :
23+ self .streaming_monitor = streaming_monitor or StreamingHealthMonitor ()
24+ self .recovery_manager = self .streaming_monitor .recovery_manager
25+ self .registered : List [str ] = []
26+ self .failover_handlers : Dict [str , Callable [[], Awaitable [None ]]] = {}
27+ self .execute_calls : List [str ] = []
28+ self .record_success_calls : List [float ] = []
29+ self .record_failure_calls : List [str ] = []
30+ self .status_providers : Dict [str , Callable [[], object ]] = {}
31+
32+ def register_provider (
33+ self ,
34+ provider_name : str ,
35+ * ,
36+ failover_handler : Optional [Callable [[], Awaitable [None ]]] = None ,
37+ status_provider : Optional [Callable [[], object ]] = None ,
38+ ) -> None :
39+ self .streaming_monitor .register_connection (provider_name )
40+ self .registered .append (provider_name )
41+ if failover_handler is not None :
42+ self .failover_handlers [provider_name ] = failover_handler
43+ if status_provider is not None :
44+ self .status_providers [provider_name ] = status_provider
45+
46+ async def execute_with_health (
47+ self ,
48+ provider_name : str ,
49+ operation : Callable [[], Awaitable [object ]],
50+ * ,
51+ fallback : Optional [Callable [[], Awaitable [object ]]] = None ,
52+ ) -> object :
53+ self .execute_calls .append (provider_name )
54+ start = time .perf_counter ()
55+ try :
56+ result = await operation ()
57+ except Exception as exc :
58+ await self .record_failure (provider_name , exc )
59+ if fallback is not None :
60+ return await fallback ()
61+ raise
62+ latency_ms = (time .perf_counter () - start ) * 1000.0
63+ await self .record_success (provider_name , latency_ms )
64+ return result
65+
66+ async def record_success (
67+ self , provider_name : str , latency_ms : float , * , bytes_received : int = 0
68+ ) -> None :
69+ self .record_success_calls .append (latency_ms )
70+
71+ async def record_failure (self , provider_name : str , error : Exception ) -> None :
72+ self .record_failure_calls .append (provider_name )
1273
1374
1475class FakeConnection :
@@ -31,11 +92,11 @@ async def close(self):
3192 pass
3293
3394
34- def test_gateway_streaming ():
35- msg = json .dumps ({"type" : "trades" , "symbol" : "TEST" , "price" : 1 })
36-
37- async def connect (url , * _ , ** __ ):
38- return FakeConnection ([msg ])
95+ def test_gateway_streaming ():
96+ msg = json .dumps ({"type" : "trades" , "symbol" : "TEST" , "price" : 1 })
97+
98+ async def connect (url , * _ , ** __ ):
99+ return FakeConnection ([msg ])
39100
40101 async def run_test ():
41102 with patch ("websockets.connect" , new = connect ):
@@ -68,4 +129,102 @@ async def run_test():
68129 except Exception :
69130 pass
70131
71- asyncio .run (run_test ())
132+ asyncio .run (run_test ())
133+
134+
135+ @patch ("quanttradeai.streaming.gateway.ProviderHealthMonitor" , new = StubProviderMonitor )
136+ def test_gateway_registers_providers_and_failover (tmp_path ):
137+ cfg = {
138+ "streaming" : {
139+ "providers" : [
140+ {
141+ "name" : "alpaca" ,
142+ "websocket_url" : "ws://test" ,
143+ "auth_method" : "none" ,
144+ }
145+ ]
146+ }
147+ }
148+ config_file = tmp_path / "streaming.yaml"
149+ config_file .write_text (yaml .safe_dump (cfg ))
150+ gateway = StreamingGateway (str (config_file ))
151+ monitor = gateway .provider_monitor
152+ assert monitor .registered == ["alpaca" ]
153+ adapter = gateway .websocket_manager .adapters [0 ]
154+ gateway .websocket_manager ._connect_with_retry = AsyncMock ()
155+ failover = monitor .failover_handlers [adapter .name ]
156+ asyncio .run (failover ())
157+ gateway .websocket_manager ._connect_with_retry .assert_awaited_once ()
158+ _ , kwargs = gateway .websocket_manager ._connect_with_retry .await_args
159+ assert kwargs ["monitor" ] is monitor
160+
161+
162+ @patch ("quanttradeai.streaming.gateway.ProviderHealthMonitor" , new = StubProviderMonitor )
163+ def test_gateway_start_uses_provider_monitor (tmp_path ):
164+ cfg = {
165+ "streaming" : {
166+ "providers" : [
167+ {
168+ "name" : "alpaca" ,
169+ "websocket_url" : "ws://test" ,
170+ "auth_method" : "none" ,
171+ }
172+ ]
173+ }
174+ }
175+ config_file = tmp_path / "streaming.yaml"
176+ config_file .write_text (yaml .safe_dump (cfg ))
177+ gateway = StreamingGateway (str (config_file ))
178+ adapter = gateway .websocket_manager .adapters [0 ]
179+ gateway .subscribe_to_trades (["TEST" ], callback = lambda _ : None )
180+ adapter .subscribe = AsyncMock (return_value = None )
181+ gateway .websocket_manager .connect_all = AsyncMock ()
182+ gateway .websocket_manager .run = AsyncMock ()
183+ gateway .health_monitor .monitor_connection_health = AsyncMock ()
184+
185+ async def run_start ():
186+ await gateway ._start ()
187+
188+ asyncio .run (run_start ())
189+
190+ gateway .websocket_manager .connect_all .assert_awaited_once ()
191+ _ , kwargs = gateway .websocket_manager .connect_all .await_args
192+ assert kwargs ["monitor" ] is gateway .provider_monitor
193+ adapter .subscribe .assert_awaited_once ()
194+ assert gateway .provider_monitor .execute_calls .count ("alpaca" ) == len (
195+ gateway ._subscriptions
196+ )
197+
198+
199+ @patch ("quanttradeai.streaming.gateway.ProviderHealthMonitor" , new = StubProviderMonitor )
200+ def test_websocket_manager_reports_failures (tmp_path ):
201+ cfg = {
202+ "streaming" : {
203+ "providers" : [
204+ {
205+ "name" : "alpaca" ,
206+ "websocket_url" : "ws://test" ,
207+ "auth_method" : "none" ,
208+ }
209+ ]
210+ }
211+ }
212+ config_file = tmp_path / "streaming.yaml"
213+ config_file .write_text (yaml .safe_dump (cfg ))
214+ gateway = StreamingGateway (str (config_file ))
215+ monitor = gateway .provider_monitor
216+ manager = gateway .websocket_manager
217+ adapter = manager .adapters [0 ]
218+ manager .reconnect_attempts = 1
219+
220+ async def failing_connect ():
221+ raise RuntimeError ("boom" )
222+
223+ adapter .connect = AsyncMock (side_effect = failing_connect )
224+
225+ async def run_connect ():
226+ await manager ._connect_with_retry (adapter , monitor = monitor )
227+
228+ with pytest .raises (RuntimeError ):
229+ asyncio .run (run_connect ())
230+ assert monitor .record_failure_calls
0 commit comments