Skip to content

Commit b9c5be1

Browse files
Added demo for correlation matrix (#441)
* Added demo for correlation matrix
1 parent 384db11 commit b9c5be1

File tree

1 file changed

+136
-0
lines changed

1 file changed

+136
-0
lines changed
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""
2+
This script computes and visualizes the correlation matrix of a selected set of
3+
stocks using Polygon's API. This script is for educational purposes only and is
4+
not intended to provide investment advice. The examples provided analyze the
5+
correlation between different stocks from diverse sectors, as well as within
6+
specific sectors.
7+
8+
Before running this script, there are 4 prerequisites:
9+
10+
1) Dependencies: Ensure that the following Python libraries are installed in
11+
your environment:
12+
- pandas
13+
- numpy
14+
- seaborn
15+
- matplotlib.pyplot
16+
- polygon's python-client library
17+
18+
You can likely run:
19+
pip install pandas numpy seaborn matplotlib polygon-api-client
20+
21+
2) API Key: You will need a Polygon API key to fetch the stock data. This can
22+
be set manually in the script below, or you can set an environment variable
23+
'POLYGON_API_KEY'.
24+
25+
setx POLYGON_API_KEY "<your_api_key>" <- windows
26+
export POLYGON_API_KEY="<your_api_key>" <- mac/linux
27+
28+
3) Select Stocks: You need to select the stocks you're interested in analyzing.
29+
Update the 'symbols' variable in this script with your chosen stock symbols.
30+
31+
4) Select Date Range: You need to specify the date range for the historical
32+
data that you want to fetch. Update the 'start_date' and 'end_date'
33+
variables in this script accordingly.
34+
35+
Understanding stock correlation is important when building a diverse portfolio,
36+
as it can help manage risk and inform investment strategies. It's always
37+
essential to do your own research or consult a financial advisor for
38+
personalized advice when investing.
39+
"""
40+
import pandas as pd # type: ignore
41+
import numpy as np # type: ignore
42+
import seaborn as sns # type: ignore
43+
import matplotlib.pyplot as plt # type: ignore
44+
from polygon import RESTClient
45+
46+
# Less likely to be correlated due to being in different sectors and are
47+
# exposed to different market forces, economic trends, and price risks.
48+
# symbols = ["TSLA", "PFE", "XOM", "HD", "JPM", "AAPL", "KO", "UNH", "LMT", "AMZN"]
49+
50+
# Here we have two groups, one with 5 technology stocks and another with 5 oil
51+
# stocks. These two groups are likely to be highly correlated within their
52+
# respective sectors but are expected to be less correlated between sectors.
53+
# symbols = ["AAPL", "MSFT", "GOOG", "ADBE", "CRM", "XOM", "CVX", "COP", "PSX", "OXY"]
54+
55+
# Likely to be highly correlated due to being in the technology sector,
56+
# specifically in the sub-industry of Semiconductors:
57+
symbols = ["INTC", "AMD", "NVDA", "TXN", "QCOM", "MU", "AVGO", "ADI", "MCHP", "NXPI"]
58+
59+
# Date range you are interested in
60+
start_date = "2022-04-01"
61+
end_date = "2023-05-10"
62+
63+
64+
def fetch_stock_data(symbols, start_date, end_date):
65+
stocks = []
66+
67+
# client = RESTClient("XXXXXX") # hardcoded api_key is used
68+
client = RESTClient() # POLYGON_API_KEY environment variable is used
69+
70+
try:
71+
for symbol in symbols:
72+
aggs = client.get_aggs(
73+
symbol,
74+
1,
75+
"day",
76+
start_date,
77+
end_date,
78+
)
79+
df = pd.DataFrame(aggs, columns=["timestamp", "close"])
80+
81+
# Filter out rows with invalid timestamps
82+
df = df[df["timestamp"] > 0]
83+
84+
df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
85+
df.set_index("timestamp", inplace=True)
86+
87+
df.rename(columns={"close": symbol}, inplace=True)
88+
stocks.append(df)
89+
finally:
90+
pass
91+
92+
merged_stocks = pd.concat(stocks, axis=1)
93+
return merged_stocks
94+
95+
96+
def calculate_daily_returns(stock_data):
97+
daily_returns = stock_data.pct_change().dropna()
98+
return daily_returns
99+
100+
101+
def compute_correlation_matrix(daily_returns):
102+
correlation_matrix = daily_returns.corr()
103+
return correlation_matrix
104+
105+
106+
def plot_correlation_heatmap(correlation_matrix):
107+
plt.figure(figsize=(8, 8))
108+
ax = sns.heatmap(
109+
correlation_matrix,
110+
annot=True,
111+
cmap="coolwarm",
112+
vmin=-1,
113+
vmax=1,
114+
square=True,
115+
linewidths=0.5,
116+
cbar_kws={"shrink": 0.8},
117+
)
118+
ax.xaxis.tick_top()
119+
ax.xaxis.set_label_position("top")
120+
plt.title("Correlation Matrix Heatmap", y=1.08)
121+
plt.show()
122+
123+
124+
def main():
125+
stock_data = fetch_stock_data(symbols, start_date, end_date)
126+
daily_returns = calculate_daily_returns(stock_data)
127+
correlation_matrix = compute_correlation_matrix(daily_returns)
128+
129+
print("Correlation Matrix:")
130+
print(correlation_matrix)
131+
132+
plot_correlation_heatmap(correlation_matrix)
133+
134+
135+
if __name__ == "__main__":
136+
main()

0 commit comments

Comments
 (0)