@@ -87,6 +87,8 @@ async def _ensure_ws(self):
8787 raise ConnectionError ("Max Retries Exceeded" )
8888
8989 async def subscribe (self , channels ):
90+ if isinstance (channels , str ):
91+ channels = [channels ]
9092 if len (channels ) > 0 :
9193 await self ._ensure_ws ()
9294 self ._streams |= set (channels )
@@ -98,9 +100,15 @@ async def subscribe(self, channels):
98100 }))
99101
100102 async def unsubscribe (self , channels ):
101- # Currently our streams don't support unsubscribe
102- # not as useful with our feeds
103- pass
103+ if isinstance (channels , str ):
104+ channels = [channels ]
105+ if len (channels ) > 0 :
106+ await self ._ws .send (json .dumps ({
107+ 'action' : 'unlisten' ,
108+ 'data' : {
109+ 'streams' : channels ,
110+ }
111+ }))
104112
105113 async def close (self ):
106114 if self ._consume_task :
@@ -158,15 +166,31 @@ def __init__(
158166 key_id = None ,
159167 secret_key = None ,
160168 base_url = None ,
161- data_url = None ):
169+ data_url = None ,
170+ data_stream = None ):
162171 _key_id , _secret_key , _ = get_credentials (key_id , secret_key )
163172 _base_url = base_url or get_base_url ()
164173 _data_url = data_url or get_data_url ()
174+ if data_stream is not None :
175+ if data_stream in ('alpacadatav1' , 'polygon' ):
176+ _data_stream = data_stream
177+ else :
178+ raise ValueError ('invalid data_stream name {}' .format (
179+ data_stream ))
180+ else :
181+ _data_stream = 'alpacadatav1'
182+ self ._data_stream = _data_stream
165183
166184 self .trading_ws = _StreamConn (_key_id , _secret_key , _base_url )
167- self .data_ws = _StreamConn (_key_id , _secret_key , _data_url )
168- self .polygon = polygon .StreamConn (
169- _key_id + '-staging' if 'staging' in _base_url else _key_id )
185+
186+ if self ._data_stream == 'polygon' :
187+ self .data_ws = polygon .StreamConn (
188+ _key_id + '-staging' if 'staging' in _base_url else _key_id )
189+ self ._data_prefixes = (('Q.' , 'T.' , 'A.' , 'AM.' ))
190+ else :
191+ self .data_ws = _StreamConn (_key_id , _secret_key , _data_url )
192+ self ._data_prefixes = (
193+ ('Q.' , 'T.' , 'AM.' , 'polyfeed/' , 'alpacadatav1/' ))
170194
171195 self ._handlers = {}
172196 self ._handler_symbols = {}
@@ -191,34 +215,41 @@ async def _ensure_ws(self, conn):
191215 async def subscribe (self , channels ):
192216 '''Start subscribing to channels.
193217 If the necessary connection isn't open yet, it opens now.
218+ This may raise ValueError if a channel is not recognized.
194219 '''
195- trading_channels , data_channels , polygon_channels = [], [], []
220+ trading_channels , data_channels = [], []
221+
196222 for c in channels :
197- if c .startswith (('Q.' , 'T.' , 'A.' , 'AM.' ,)):
198- polygon_channels .append (c )
199- elif c in ('trade_updates' , 'account_updates' ):
223+ if c in ('trade_updates' , 'account_updates' ):
200224 trading_channels .append (c )
201- else :
225+ elif c . startswith ( self . _data_prefixes ) :
202226 data_channels .append (c )
227+ else :
228+ raise ValueError (
229+ ('unknown channel {} (you may need to specify ' +
230+ 'the right data_stream)' ).format (c ))
203231
204232 if trading_channels :
205233 await self ._ensure_ws (self .trading_ws )
206234 await self .trading_ws .subscribe (trading_channels )
207235 if data_channels :
208236 await self ._ensure_ws (self .data_ws )
209237 await self .data_ws .subscribe (data_channels )
210- if polygon_channels :
211- await self ._ensure_ws (self .polygon )
212- await self .polygon .subscribe (polygon_channels )
213238
214239 async def unsubscribe (self , channels ):
215240 '''Handle unsubscribing from channels.'''
216- polygon_channels = [
241+
242+ data_prefixes = ('Q.' , 'T.' , 'AM.' )
243+ if self ._data_stream == 'polygon' :
244+ data_prefixes = ('Q.' , 'T.' , 'A.' , 'AM.' )
245+
246+ data_channels = [
217247 c for c in channels
218- if c .startswith (( 'Q.' , 'T.' , 'A.' , 'AM.' ,) )
248+ if c .startswith (data_prefixes )
219249 ]
220- if polygon_channels :
221- await self .polygon .unsubscribe (polygon_channels )
250+
251+ if data_channels :
252+ await self .data_ws .unsubscribe (data_channels )
222253
223254 def run (self , initial_channels = []):
224255 '''Run forever and block until exception is raised.
@@ -242,9 +273,6 @@ async def close(self):
242273 if self .data_ws is not None :
243274 await self .data_ws .close ()
244275 self .data_ws = None
245- if self .polygon is not None :
246- await self .polygon .close ()
247- self .polygon = None
248276
249277 def on (self , channel_pat , symbols = None ):
250278 def decorator (func ):
@@ -265,8 +293,6 @@ def register(self, channel_pat, func, symbols=None):
265293 self .trading_ws .register (channel_pat , func , symbols )
266294 if self .data_ws :
267295 self .data_ws .register (channel_pat , func , symbols )
268- if self .polygon :
269- self .polygon .register (channel_pat , func , symbols )
270296
271297 def deregister (self , channel_pat ):
272298 if isinstance (channel_pat , str ):
@@ -278,5 +304,3 @@ def deregister(self, channel_pat):
278304 self .trading_ws .deregister (channel_pat )
279305 if self .data_ws :
280306 self .data_ws .deregister (channel_pat )
281- if self .polygon :
282- self .polygon .deregister (channel_pat )
0 commit comments