Skip to content

Commit 6132575

Browse files
authored
Add news api (#555)
* added news stream * added news historical * fixed flake8 format * Added rest api test for news * news api default limit and page limit, other pr comments * pr comments * pr comments
1 parent f36d6c4 commit 6132575

File tree

4 files changed

+233
-9
lines changed

4 files changed

+233
-9
lines changed

alpaca_trade_api/entity_v2.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,16 @@ def __init__(self, raw):
225225
self[k] = _convert_or_none(QuoteV2, v)
226226

227227

228+
class NewsV2(Entity):
229+
def __init__(self, raw):
230+
super().__init__(raw)
231+
232+
233+
class NewsListV2(list):
234+
def __init__(self, raw):
235+
super().__init__([NewsV2(o) for o in raw])
236+
237+
228238
def _convert_or_none(entityType, value):
229239
if value:
230240
return entityType(value)

alpaca_trade_api/rest.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
)
2121
from .entity_v2 import (
2222
BarV2, BarsV2, LatestBarsV2, LatestQuotesV2, LatestTradesV2,
23-
SnapshotV2, SnapshotsV2, TradesV2, TradeV2, QuotesV2, QuoteV2)
23+
SnapshotV2, SnapshotsV2, TradesV2, TradeV2, QuotesV2, QuoteV2,
24+
NewsV2, NewsListV2
25+
)
2426

2527
logger = logging.getLogger(__name__)
2628
Positions = List[Position]
@@ -32,8 +34,10 @@
3234
TradeIterator = Iterator[Union[Trade, dict]]
3335
QuoteIterator = Iterator[Union[Quote, dict]]
3436
BarIterator = Iterator[Union[Bar, dict]]
37+
NewsIterator = Iterator[Union[NewsV2, dict]]
3538

3639
DATA_V2_MAX_LIMIT = 10000 # max items per api call
40+
NEWS_MAX_LIMIT = 50 # max items per api call
3741

3842

3943
class RetryException(Exception):
@@ -129,6 +133,14 @@ def validate(amount: int, unit: TimeFrameUnit):
129133
TimeFrame.Day = TimeFrame(1, TimeFrameUnit.Day)
130134

131135

136+
class Sort(Enum):
137+
Asc = "asc"
138+
Desc = "desc"
139+
140+
def __str__(self):
141+
return self.value
142+
143+
132144
class REST(object):
133145
def __init__(self,
134146
key_id: str = None,
@@ -609,27 +621,34 @@ def _data_get(self,
609621
symbol_or_symbols: Union[str, List[str]],
610622
api_version: str = 'v2',
611623
endpoint_base: str = 'stocks',
624+
resp_grouped_by_symbol: Optional[bool] = None,
625+
page_limit: int = DATA_V2_MAX_LIMIT,
612626
**kwargs):
613627
page_token = None
614628
total_items = 0
615629
limit = kwargs.get('limit')
630+
if resp_grouped_by_symbol is None:
631+
resp_grouped_by_symbol = not isinstance(symbol_or_symbols, str)
616632
while True:
617633
actual_limit = None
618634
if limit:
619-
actual_limit = min(int(limit) - total_items, DATA_V2_MAX_LIMIT)
635+
actual_limit = min(int(limit) - total_items, page_limit)
620636
if actual_limit < 1:
621637
break
622638
data = kwargs
623639
data['limit'] = actual_limit
624640
data['page_token'] = page_token
625-
if isinstance(symbol_or_symbols, str):
626-
path = f'/{endpoint_base}/{symbol_or_symbols}/{endpoint}'
641+
path = f'/{endpoint_base}'
642+
if isinstance(symbol_or_symbols, str) and symbol_or_symbols:
643+
path += f'/{symbol_or_symbols}'
627644
else:
628-
path = f'/{endpoint_base}/{endpoint}'
629645
data['symbols'] = ','.join(symbol_or_symbols)
646+
if endpoint:
647+
path += f'/{endpoint}'
630648
resp = self.data_get(path, data=data, api_version=api_version)
631-
if isinstance(symbol_or_symbols, str):
632-
for item in resp.get(endpoint, []) or []:
649+
if not resp_grouped_by_symbol:
650+
k = endpoint or endpoint_base
651+
for item in resp.get(k, []) or []:
633652
yield item
634653
total_items += 1
635654
else:
@@ -893,6 +912,50 @@ def get_crypto_snapshot(self, symbol: str, exchange: str) -> SnapshotV2:
893912
api_version='v1beta1')
894913
return self.response_wrapper(resp, SnapshotV2)
895914

915+
def get_news_iter(self,
916+
symbol: Optional[Union[str, List[str]]] = None,
917+
start: Optional[str] = None,
918+
end: Optional[str] = None,
919+
limit: int = 10,
920+
sort: Sort = Sort.Desc,
921+
include_content: bool = False,
922+
exclude_contentless: bool = False,
923+
raw=False) -> NewsIterator:
924+
symbol = symbol or []
925+
# Avoid passing symbol as path param
926+
if isinstance(symbol, str):
927+
symbol = [symbol]
928+
news = self._data_get('', symbol,
929+
api_version='v1beta1', endpoint_base='news',
930+
start=start, end=end, limit=limit, sort=sort,
931+
include_content=include_content,
932+
exclude_contentless=exclude_contentless,
933+
resp_grouped_by_symbol=False,
934+
page_limit=NEWS_MAX_LIMIT)
935+
for n in news:
936+
if raw:
937+
yield n
938+
else:
939+
yield self.response_wrapper(n, NewsV2)
940+
941+
def get_news(self,
942+
symbol: Optional[Union[str, List[str]]] = None,
943+
start: Optional[str] = None,
944+
end: Optional[str] = None,
945+
limit: int = 10,
946+
sort: Sort = Sort.Desc,
947+
include_content: bool = False,
948+
exclude_contentless: bool = False,
949+
950+
) -> NewsListV2:
951+
news = list(self.get_news_iter(symbol=symbol,
952+
start=start, end=end,
953+
limit=limit, sort=sort,
954+
include_content=include_content,
955+
exclude_contentless=exclude_contentless,
956+
raw=True))
957+
return NewsListV2(news)
958+
896959
def get_clock(self) -> Clock:
897960
resp = self.get('/clock')
898961
return self.response_wrapper(resp, Clock)

alpaca_trade_api/stream.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
LULDV2,
3030
CancelErrorV2,
3131
CorrectionV2,
32+
NewsV2,
3233
)
3334

3435
log = logging.getLogger(__name__)
@@ -432,6 +433,63 @@ def __init__(self,
432433
self._name = 'crypto data'
433434

434435

436+
class NewsDataStream(_DataStream):
437+
def __init__(self,
438+
key_id: str,
439+
secret_key: str,
440+
base_url: URL,
441+
raw_data: bool):
442+
self._key_id = key_id
443+
self._secret_key = secret_key
444+
base_url = re.sub(r'^http', 'ws', base_url)
445+
endpoint = base_url + '/v1beta1/news'
446+
super().__init__(endpoint=endpoint,
447+
key_id=key_id,
448+
secret_key=secret_key,
449+
raw_data=raw_data,
450+
)
451+
self._handlers = {
452+
'news': {},
453+
}
454+
self._name = 'news data'
455+
456+
def _cast(self, msg_type, msg):
457+
result = super()._cast(msg_type, msg)
458+
if not self._raw_data:
459+
if msg_type == 'n':
460+
result = NewsV2(msg)
461+
return result
462+
463+
async def _dispatch(self, msg):
464+
msg_type = msg.get('T')
465+
symbol = msg.get('S')
466+
if msg_type == 'n':
467+
handler = self._handlers['news'].get(
468+
symbol, self._handlers['news'].get('*', None))
469+
if handler:
470+
await handler(self._cast(msg_type, msg))
471+
else:
472+
await super()._dispatch(msg)
473+
474+
async def _unsubscribe(self, news=()):
475+
if news:
476+
await self._ws.send(
477+
msgpack.packb({
478+
'action': 'unsubscribe',
479+
'news': news,
480+
}))
481+
482+
def subscribe_news(self, handler, *symbols):
483+
self._subscribe(handler, symbols, self._handlers['news'])
484+
485+
def unsubscribe_news(self, *symbols):
486+
if self._running:
487+
asyncio.get_event_loop().run_until_complete(
488+
self._unsubscribe(news=symbols))
489+
for symbol in symbols:
490+
del self._handlers['news'][symbol]
491+
492+
435493
class TradingStream:
436494
def __init__(self,
437495
key_id: str,
@@ -588,6 +646,10 @@ def __init__(self,
588646
self._data_steam_url,
589647
raw_data,
590648
crypto_exchanges)
649+
self._news_ws = NewsDataStream(self._key_id,
650+
self._secret_key,
651+
self._data_steam_url,
652+
raw_data)
591653

592654
def subscribe_trade_updates(self, handler):
593655
self._trading_ws.subscribe_trade_updates(handler)
@@ -634,6 +696,9 @@ def subscribe_crypto_bars(self, handler, *symbols):
634696
def subscribe_crypto_daily_bars(self, handler, *symbols):
635697
self._crypto_ws.subscribe_daily_bars(handler, *symbols)
636698

699+
def subscribe_news(self, handler, *symbols):
700+
self._news_ws.subscribe_news(handler, *symbols)
701+
637702
def on_trade_update(self, func):
638703
self.subscribe_trade_updates(func)
639704
return func
@@ -722,6 +787,13 @@ def decorator(func):
722787

723788
return decorator
724789

790+
def on_news(self, *symbols):
791+
def decorator(func):
792+
self.subscribe_news(func, *symbols)
793+
return func
794+
795+
return decorator
796+
725797
def unsubscribe_trades(self, *symbols):
726798
self._data_ws.unsubscribe_trades(*symbols)
727799
self._data_ws.unregister_handler("cancelErrors", *symbols)
@@ -754,10 +826,14 @@ def unsubscribe_crypto_bars(self, *symbols):
754826
def unsubscribe_crypto_daily_bars(self, *symbols):
755827
self._crypto_ws.unsubscribe_daily_bars(*symbols)
756828

829+
def unsubscribe_news(self, *symbols):
830+
self._news_ws.unsubscribe_news(*symbols)
831+
757832
async def _run_forever(self):
758833
await asyncio.gather(self._trading_ws._run_forever(),
759834
self._data_ws._run_forever(),
760-
self._crypto_ws._run_forever())
835+
self._crypto_ws._run_forever(),
836+
self._news_ws._run_forever())
761837

762838
def run(self):
763839
loop = asyncio.get_event_loop()
@@ -780,12 +856,16 @@ async def stop_ws(self):
780856
if self._crypto_ws:
781857
await self._crypto_ws.stop_ws()
782858

859+
if self._news_ws:
860+
await self._news_ws.stop_ws()
861+
783862
def is_open(self):
784863
"""
785864
Checks if either of the websockets is open
786865
:return:
787866
"""
788-
open_ws = self._trading_ws._ws or self._data_ws._ws or self._crypto_ws._ws # noqa
867+
open_ws = (self._trading_ws._ws or self._data_ws._ws
868+
or self._crypto_ws._ws or self._news_ws) # noqa
789869
if open_ws:
790870
return True
791871
return False

tests/test_rest.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,77 @@ def test_data(reqmock):
788788
assert msft_snapshot.prev_daily_bar is None
789789
assert snapshots.get('INVALID') is None
790790

791+
# News
792+
reqmock.get(
793+
'https://data.alpaca.markets/v1beta1/news' +
794+
'?symbols=AAPL,TSLA&limit=2',
795+
text='''
796+
{
797+
"news": [
798+
{
799+
"id": 24994117,
800+
"headline": "'Tesla Approved...",
801+
"author": "Benzinga Newsdesk",
802+
"created_at": "2022-01-11T13:50:47Z",
803+
"updated_at": "2022-01-11T13:50:47Z",
804+
"summary": "",
805+
"url": "https://www.benzinga.com/news/some/path",
806+
"images": [],
807+
"symbols": [
808+
"TSLA"
809+
],
810+
"source": "benzinga"
811+
},
812+
{
813+
"id": 24993189,
814+
"headline": "Dogecoin Is Down 80% ...",
815+
"author": "Samyuktha Sriram",
816+
"created_at": "2022-01-11T13:49:40Z",
817+
"updated_at": "2022-01-11T13:49:41Z",
818+
"summary": "Popular meme-based cryptocurrency...",
819+
"url": "https://www.benzinga.com/markets/some/path",
820+
"images": [
821+
{
822+
"size": "large",
823+
"url": "https://cdn.benzinga.com/files/some.jpeg"
824+
},
825+
{
826+
"size": "small",
827+
"url": "https://cdn.benzinga.com/files/some.jpeg"
828+
},
829+
{
830+
"size": "thumb",
831+
"url": "https://cdn.benzinga.com/files/some.jpeg"
832+
}
833+
],
834+
"symbols": [
835+
"BTCUSD",
836+
"DOGEUSD",
837+
"SHIBUSD",
838+
"TSLA"
839+
],
840+
"source": "benzinga"
841+
}
842+
]
843+
}
844+
'''
845+
)
846+
news = api.get_news(['AAPL', 'TSLA'], limit=2)
847+
assert len(news) == 2
848+
first = news[0]
849+
assert first is not None
850+
assert first.author == 'Benzinga Newsdesk'
851+
assert 'TSLA' in first.symbols
852+
assert first.source == 'benzinga'
853+
assert type(first) == tradeapi.entity_v2.NewsV2
854+
second = news[1]
855+
assert second is not None
856+
assert second.headline != ''
857+
assert type(second.images) == list
858+
assert 'TSLA' in second.symbols
859+
assert second.source == 'benzinga'
860+
assert type(second) == tradeapi.entity_v2.NewsV2
861+
791862

792863
def test_timeframe(reqmock):
793864
# Custom timeframe: Minutes

0 commit comments

Comments
 (0)