Skip to content

Commit 9aae2a7

Browse files
author
Gianluca
committed
create separate file for plotting
1 parent 4dde05b commit 9aae2a7

File tree

2 files changed

+244
-157
lines changed

2 files changed

+244
-157
lines changed

plotting.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
import os
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
5+
def compute_uncertainty_bounds(est: np.array, std: np.array):
6+
return np.maximum(0, est - 2 * std), est + 2 * std
7+
8+
def plot_market_estimates(data: dict, est: np.array, std: np.array):
9+
"""
10+
It makes a market estimation plot with prices, trends, uncertainties and volumes.
11+
12+
Parameters
13+
----------
14+
data: dict
15+
Downloaded data.
16+
est: np.array
17+
Price trend estimate at market-level.
18+
std: np.array
19+
Standard deviation estimate of price trend at market-level.
20+
"""
21+
print('\nPlotting market estimation...')
22+
fig = plt.figure(figsize=(10, 3))
23+
logp = np.log(data['price'])
24+
t = logp.shape[1]
25+
lb, ub = compute_uncertainty_bounds(est, std)
26+
27+
plt.grid(axis='both')
28+
plt.title("Market", fontsize=15)
29+
avg_price = np.exp(logp.mean(0))
30+
l1 = plt.plot(data["dates"], avg_price, label="avg. price in {}".format(data['default_currency']), color="C0")
31+
l2 = plt.plot(data["dates"], est[0], label="trend", color="C1")
32+
l3 = plt.fill_between(data["dates"], lb[0], ub[0], alpha=0.2, label="+/- 2 st. dev.", color="C0")
33+
plt.ylabel("avg. price in {}".format(data['default_currency']), fontsize=12)
34+
plt.twinx()
35+
l4 = plt.bar(data["dates"], data['volume'].mean(0), width=1, color='g', alpha=0.2, label='avg. volume')
36+
l4[0].set_edgecolor('r')
37+
for d in range(1, t):
38+
if avg_price[d] - avg_price[d - 1] < 0:
39+
l4[d].set_color('r')
40+
plt.ylabel("avg. volume", fontsize=12)
41+
ll = l1 + l2 + [l3] + [l4]
42+
labels = [l.get_label() for l in ll]
43+
plt.legend(ll, labels, loc="upper left")
44+
fig_name = 'market_estimation.png'
45+
fig.savefig(fig_name, dpi=fig.dpi)
46+
print('Market estimation plot has been saved to {}/{}.'.format(os.getcwd(), fig_name))
47+
48+
def plot_sector_estimates(data: dict, info: dict, est: np.array, std: np.array):
49+
"""
50+
It makes a plot for each sector with prices, trends, uncertainties and volumes.
51+
52+
Parameters
53+
----------
54+
data: dict
55+
Downloaded data.
56+
info: dict
57+
Model hierarchy information.
58+
est: np.array
59+
Price trend estimate at sector-level.
60+
std: np.array
61+
Standard deviation estimate of price trend at sector-level.
62+
"""
63+
print('\nPlotting sector estimation...')
64+
num_columns = 3
65+
logp = np.log(data['price'])
66+
t = logp.shape[1]
67+
lb, ub = compute_uncertainty_bounds(est, std)
68+
69+
NA_sectors = np.where(np.array([sec[:2] for sec in info['unique_sectors']]) == "NA")[0]
70+
num_NA_sectors = len(NA_sectors)
71+
72+
fig = plt.figure(figsize=(20, max(info['num_sectors'] - num_NA_sectors, 5)))
73+
j = 0
74+
for i in range(info['num_sectors']):
75+
if i not in NA_sectors:
76+
j += 1
77+
plt.subplot(int(np.ceil((info['num_sectors'] - num_NA_sectors) / num_columns)), num_columns, j)
78+
plt.grid(axis='both')
79+
plt.title(info['unique_sectors'][i], fontsize=15)
80+
idx_sectors = np.where(np.array(info['sectors_id']) == i)[0]
81+
avg_price = np.exp(logp[idx_sectors].reshape(-1, t).mean(0))
82+
l1 = plt.plot(data["dates"], avg_price,
83+
label="avg. price in {}".format(data['default_currency']), color="C0")
84+
l2 = plt.plot(data["dates"], est[i], label="trend", color="C1")
85+
l3 = plt.fill_between(data["dates"], lb[i], ub[i], alpha=0.2, label="+/- 2 st. dev.",
86+
color="C0")
87+
plt.ylabel("avg. price in {}".format(data['default_currency']), fontsize=12)
88+
plt.xticks(rotation=45)
89+
plt.twinx()
90+
l4 = plt.bar(data["dates"],
91+
data['volume'][np.where(np.array(info['sectors_id']) == i)[0]].reshape(-1, t).mean(0),
92+
width=1, color='g', alpha=0.2, label='avg. volume')
93+
for d in range(1, t):
94+
if avg_price[d] - avg_price[d - 1] < 0:
95+
l4[d].set_color('r')
96+
l4[0].set_edgecolor('r')
97+
plt.ylabel("avg. volume", fontsize=12)
98+
ll = l1 + l2 + [l3] + [l4]
99+
labels = [l.get_label() for l in ll]
100+
plt.legend(ll, labels, loc="upper left")
101+
102+
plt.tight_layout()
103+
fig_name = 'sector_estimation.png'
104+
fig.savefig(fig_name, dpi=fig.dpi)
105+
print('Sector estimation plot has been saved to {}/{}.'.format(os.getcwd(), fig_name))
106+
107+
def plot_industry_estimates(data: dict, info: dict, est: np.array, std: np.array):
108+
"""
109+
It makes a plot for each industry with prices, trends, uncertainties and volumes.
110+
111+
Parameters
112+
----------
113+
data: dict
114+
Downloaded data.
115+
info: dict
116+
Model hierarchy information.
117+
est: np.array
118+
Price trend estimate at industry-level.
119+
std: np.array
120+
Standard deviation estimate of price trend at industry-level.
121+
"""
122+
print('\nPlotting industry estimation...')
123+
num_columns = 3
124+
logp = np.log(data['price'])
125+
t = logp.shape[1]
126+
lb, ub = compute_uncertainty_bounds(est, std)
127+
128+
NA_industries = np.where(np.array([ind[:2] for ind in info['unique_industries']]) == "NA")[0]
129+
num_NA_industries = len(NA_industries)
130+
131+
fig = plt.figure(figsize=(20, max(info['num_industries'] - num_NA_industries, 5)))
132+
j = 0
133+
for i in range(info['num_industries']):
134+
if i not in NA_industries:
135+
j += 1
136+
plt.subplot(int(np.ceil((info['num_industries'] - num_NA_industries) / num_columns)), num_columns, j)
137+
plt.grid(axis='both')
138+
plt.title(info['unique_industries'][i], fontsize=15)
139+
idx_industries = np.where(np.array(info['industries_id']) == i)[0]
140+
plt.title(info['unique_industries'][i], fontsize=15)
141+
avg_price = np.exp(logp[idx_industries].reshape(-1, t).mean(0))
142+
l1 = plt.plot(data["dates"], avg_price,
143+
label="avg. price in {}".format(data['default_currency']), color="C0")
144+
l2 = plt.plot(data["dates"], est[i], label="trend", color="C1")
145+
l3 = plt.fill_between(data["dates"], lb[i], ub[i], alpha=0.2, label="+/- 2 st. dev.",
146+
color="C0")
147+
plt.ylabel("avg. price in {}".format(data['default_currency']), fontsize=12)
148+
plt.xticks(rotation=45)
149+
plt.twinx()
150+
l4 = plt.bar(data["dates"],
151+
data['volume'][np.where(np.array(info['industries_id']) == i)[0]].reshape(-1, t).mean(0),
152+
width=1, color='g', alpha=0.2, label='avg. volume')
153+
for d in range(1, t):
154+
if avg_price[d] - avg_price[d - 1] < 0:
155+
l4[d].set_color('r')
156+
l4[0].set_edgecolor('r')
157+
plt.ylabel("avg. volume", fontsize=12)
158+
ll = l1 + l2 + [l3] + [l4]
159+
labels = [l.get_label() for l in ll]
160+
plt.legend(ll, labels, loc="upper left")
161+
plt.tight_layout()
162+
fig_name = 'industry_estimation.png'
163+
fig.savefig(fig_name, dpi=fig.dpi)
164+
print('Industry estimation plot has been saved to {}/{}.'.format(os.getcwd(), fig_name))
165+
166+
def plot_stock_estimates(data: dict, est: np.array, std: np.array, rank_type: str, rank: list, ranked_rates: np.array):
167+
"""
168+
It makes a plot for each stock with prices, trends, uncertainties and volumes.
169+
170+
Parameters
171+
----------
172+
data: dict
173+
Downloaded data.
174+
est: np.array
175+
Price trend estimate at stock-level.
176+
std: np.array
177+
Standard deviation estimate of price trend at stock-level.
178+
rank_type: str
179+
Type of rank. It can be either `rate` or `growth`.
180+
rank: list
181+
List of integers at stock-level indicating the rank specified in `rank_type`.
182+
ranked_rates: np.array
183+
Array of rates at stock-level ranked according to `rank`.
184+
"""
185+
num_stocks, t = data['price'].shape
186+
187+
# determine which stocks are along trend to avoid plotting them
188+
if rank_type == "rate":
189+
to_plot = np.where(np.array(ranked_rates) != "ALONG TREND")[0]
190+
else:
191+
to_plot = np.where(np.array(ranked_rates) == "ALONG TREND")[0][:99]
192+
dont_plot = [x for x in np.arange(num_stocks) if x not in to_plot]
193+
num_to_plot = len(to_plot)
194+
if num_to_plot > 0:
195+
print('\nPlotting stock estimation...')
196+
num_columns = 3
197+
198+
ranked_tickers = np.array(data['tickers'])[rank]
199+
ranked_p = data['price'][rank]
200+
ranked_volume = data['volume'][rank]
201+
ranked_currencies = np.array(data['currencies'])[rank]
202+
ranked_est = est[rank]
203+
ranked_std = std[rank]
204+
205+
ranked_lb, ranked_ub = compute_uncertainty_bounds(ranked_est, ranked_std)
206+
207+
j = 0
208+
fig = plt.figure(figsize=(20, max(num_to_plot, 5)))
209+
for i in range(num_stocks):
210+
if i not in dont_plot:
211+
j += 1
212+
plt.subplot(int(np.ceil(num_to_plot / num_columns)), num_columns, j)
213+
plt.grid(axis='both')
214+
plt.title(ranked_tickers[i], fontsize=15)
215+
l1 = plt.plot(data["dates"], ranked_p[i], label="price in {}".format(ranked_currencies[i]))
216+
l2 = plt.plot(data["dates"], ranked_est[i], label="trend")
217+
l3 = plt.fill_between(data["dates"], ranked_lb[i], ranked_ub[i], alpha=0.2,
218+
label="+/- 2 st. dev.")
219+
plt.yticks(fontsize=12)
220+
plt.xticks(rotation=45)
221+
plt.ylabel("price in {}".format(ranked_currencies[i]), fontsize=12)
222+
plt.twinx()
223+
l4 = plt.bar(data["dates"], ranked_volume[i], width=1, color='g', alpha=0.2, label='volume')
224+
for d in range(1, t):
225+
if ranked_p[i, d] - ranked_p[i, d - 1] < 0:
226+
l4[d].set_color('r')
227+
l4[0].set_edgecolor('r')
228+
plt.ylabel("volume", fontsize=12)
229+
ll = l1 + l2 + [l3] + [l4]
230+
labels = [l.get_label() for l in ll]
231+
plt.legend(ll, labels, loc="upper left")
232+
plt.tight_layout()
233+
fig_name = 'stock_estimation.png'
234+
fig.savefig(fig_name, dpi=fig.dpi)
235+
print('Stock estimation plot has been saved to {}/{}.'.format(os.getcwd(), fig_name))
236+
237+
elif os.path.exists('stock_estimation.png'):
238+
os.remove('stock_estimation.png')

0 commit comments

Comments
 (0)