Skip to content

Commit 1c9a6f1

Browse files
authored
V1.0.2 (#35)
* update version * added test for predictor * refactor alpaca.py * update requirements.txt
1 parent bd5d2e6 commit 1c9a6f1

File tree

5 files changed

+83
-27
lines changed

5 files changed

+83
-27
lines changed

py_alpaca_api/alpaca.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,50 +14,49 @@ class PyAlpacaApi:
1414
def __init__(self, api_key: str, api_secret: str, api_paper: bool = True):
1515
"""
1616
Initializes an instance of the Alpaca class.
17-
1817
Args:
1918
api_key (str): The API key for accessing the Alpaca API.
2019
api_secret (str): The API secret for accessing the Alpaca API.
2120
api_paper (bool, optional): Specifies whether to use the Alpaca paper trading API.
2221
Defaults to True.
23-
2422
Raises:
2523
ValueError: If the API key or API secret is not provided.
2624
"""
27-
if not api_key:
28-
raise ValueError("API Key is required")
29-
if not api_secret:
30-
raise ValueError("API Secret is required")
3125

32-
# Set the API Key and Secret
26+
# Check API Key and Secret
27+
self._validate_api_key_and_secret(api_key, api_secret)
28+
29+
# Set Headers
3330
self.headers = {
3431
"APCA-API-KEY-ID": api_key,
3532
"APCA-API-SECRET-KEY": api_secret,
3633
}
3734

38-
# Set the API URL's
39-
if api_paper:
40-
self.trade_url = "https://paper-api.alpaca.markets/v2"
41-
else:
42-
self.trade_url = "https://api.alpaca.markets/v2"
43-
35+
# Set URLs
4436
self.data_url = "https://data.alpaca.markets/v2"
37+
self.trade_url = self._set_trade_url(api_paper)
38+
39+
# Initialize Components
40+
self._initialize_components()
41+
42+
@staticmethod
43+
def _validate_api_key_and_secret(api_key: str, api_secret: str):
44+
if not api_key:
45+
raise ValueError("API Key is required")
46+
if not api_secret:
47+
raise ValueError("API Secret is required")
48+
49+
@staticmethod
50+
def _set_trade_url(api_paper: bool):
51+
return "https://paper-api.alpaca.markets/v2" if api_paper else "https://api.alpaca.markets/v2"
4552

53+
def _initialize_components(self):
4654
self.account = Account(trade_url=self.trade_url, headers=self.headers)
4755
self.asset = Asset(trade_url=self.trade_url, headers=self.headers)
4856
self.history = History(data_url=self.data_url, headers=self.headers, asset=self.asset)
49-
self.position = Position(
50-
trade_url=self.trade_url,
51-
headers=self.headers,
52-
account=self.account,
53-
)
57+
self.position = Position(trade_url=self.trade_url, headers=self.headers, account=self.account)
5458
self.order = Order(trade_url=self.trade_url, headers=self.headers)
5559
self.market = Market(trade_url=self.trade_url, headers=self.headers)
5660
self.watchlist = Watchlist(trade_url=self.trade_url, headers=self.headers)
57-
self.screener = Screener(
58-
data_url=self.data_url,
59-
headers=self.headers,
60-
asset=self.asset,
61-
market=self.market,
62-
)
61+
self.screener = Screener(data_url=self.data_url, headers=self.headers, asset=self.asset, market=self.market)
6362
self.predictor = Predictor(history=self.history, screener=self.screener)

py_alpaca_api/src/predictor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
yesterday = pendulum.now().subtract(days=1).format("YYYY-MM-DD")
1212
four_years_ago = pendulum.now().subtract(years=2).format("YYYY-MM-DD")
1313

14-
1514
logger = logging.getLogger("cmdstanpy")
1615
logger.disabled = True
1716
logger.propagate = False

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "py-alpaca-api"
3-
version = "1.0.1"
3+
version = "1.0.2"
44
description = "Python package, for communicating with Alpaca Markets REST API."
55
authors = ["TexasCoding <[email protected]>"]
66
license = "MIT"

requirements.txt

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,29 @@
1-
certifi==2024.2.2 ; python_version >= "3.12" and python_version < "4.0"
1+
certifi==2024.6.2 ; python_version >= "3.12" and python_version < "4.0"
22
charset-normalizer==3.3.2 ; python_version >= "3.12" and python_version < "4.0"
3+
cmdstanpy==1.2.2 ; python_version >= "3.12" and python_version < "4.0"
4+
colorama==0.4.6 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Windows"
5+
contourpy==1.2.1 ; python_version >= "3.12" and python_version < "4.0"
6+
cycler==0.12.1 ; python_version >= "3.12" and python_version < "4.0"
7+
fonttools==4.53.0 ; python_version >= "3.12" and python_version < "4.0"
8+
holidays==0.49 ; python_version >= "3.12" and python_version < "4.0"
39
idna==3.7 ; python_version >= "3.12" and python_version < "4.0"
10+
importlib-resources==6.4.0 ; python_version >= "3.12" and python_version < "4.0"
11+
kiwisolver==1.4.5 ; python_version >= "3.12" and python_version < "4.0"
12+
matplotlib==3.9.0 ; python_version >= "3.12" and python_version < "4.0"
413
numpy==1.26.4 ; python_version >= "3.12" and python_version < "4.0"
14+
packaging==24.0 ; python_version >= "3.12" and python_version < "4.0"
515
pandas==2.2.2 ; python_version >= "3.12" and python_version < "4.0"
616
pendulum==3.0.0 ; python_version >= "3.12" and python_version < "4.0"
17+
pillow==10.3.0 ; python_version >= "3.12" and python_version < "4.0"
18+
plotly==5.22.0 ; python_version >= "3.12" and python_version < "4.0"
19+
prophet==1.1.5 ; python_version >= "3.12" and python_version < "4.0"
20+
pyparsing==3.1.2 ; python_version >= "3.12" and python_version < "4.0"
721
python-dateutil==2.9.0.post0 ; python_version >= "3.12" and python_version < "4.0"
822
pytz==2024.1 ; python_version >= "3.12" and python_version < "4.0"
923
requests==2.32.3 ; python_version >= "3.12" and python_version < "4.0"
1024
six==1.16.0 ; python_version >= "3.12" and python_version < "4.0"
25+
stanio==0.5.0 ; python_version >= "3.12" and python_version < "4.0"
26+
tenacity==8.3.0 ; python_version >= "3.12" and python_version < "4.0"
27+
tqdm==4.66.4 ; python_version >= "3.12" and python_version < "4.0"
1128
tzdata==2024.1 ; python_version >= "3.12" and python_version < "4.0"
1229
urllib3==2.2.1 ; python_version >= "3.12" and python_version < "4.0"

tests/test_predictor.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import sys
2+
from unittest.mock import MagicMock
3+
4+
import pandas as pd
5+
import pytest
6+
7+
from py_alpaca_api.src.predictor import Predictor
8+
9+
sys.path.append("py_alpaca_api/src")
10+
11+
12+
@pytest.fixture(scope="module")
13+
def predictor():
14+
history = MagicMock()
15+
screener = MagicMock()
16+
return Predictor(history, screener)
17+
18+
19+
def test_get_losers_to_gainers_with_exception_handling(predictor):
20+
predictor.get_stock_data = MagicMock(return_value=pd.DataFrame())
21+
predictor.train_prophet_model = MagicMock(return_value=MagicMock())
22+
predictor.generate_forecast = MagicMock(side_effect=Exception("Random error"))
23+
assert predictor.get_losers_to_gainers() == []
24+
25+
26+
def test_get_losers_to_gainers_when_forecast_price_is_less_than_previous_price(predictor):
27+
ticker = "AAPL"
28+
predictor.get_stock_data = MagicMock(return_value=pd.DataFrame())
29+
predictor.train_prophet_model = MagicMock(return_value=MagicMock())
30+
predictor.screener.losers = MagicMock(return_value=pd.DataFrame({"symbol": [ticker], "price": [10]}))
31+
predictor.generate_forecast = MagicMock(return_value=5)
32+
assert predictor.get_losers_to_gainers() == []
33+
34+
35+
def test_get_losers_to_gainers_when_forecast_price_is_greater_than_previous_price(predictor):
36+
ticker = "AAPL"
37+
predictor.get_stock_data = MagicMock(return_value=pd.DataFrame())
38+
predictor.train_prophet_model = MagicMock(return_value=MagicMock())
39+
predictor.screener.losers = MagicMock(return_value=pd.DataFrame({"symbol": [ticker], "price": [10.00]}))
40+
predictor.generate_forecast = MagicMock(return_value=20.00)
41+
assert predictor.get_losers_to_gainers() == [ticker]

0 commit comments

Comments
 (0)