1+ import os
2+ import numpy as np
3+ import matplotlib .pyplot as plt
4+
5+ def compute_uncertainty_bounds (est : np .array , std : np .array ):
6+ return np .maximum (0 , est - 2 * std ), est + 2 * std
7+
8+ def plot_market_estimates (data : dict , est : np .array , std : np .array ):
9+ """
10+ It makes a market estimation plot with prices, trends, uncertainties and volumes.
11+
12+ Parameters
13+ ----------
14+ data: dict
15+ Downloaded data.
16+ est: np.array
17+ Price trend estimate at market-level.
18+ std: np.array
19+ Standard deviation estimate of price trend at market-level.
20+ """
21+ print ('\n Plotting market estimation...' )
22+ fig = plt .figure (figsize = (10 , 3 ))
23+ logp = np .log (data ['price' ])
24+ t = logp .shape [1 ]
25+ lb , ub = compute_uncertainty_bounds (est , std )
26+
27+ plt .grid (axis = 'both' )
28+ plt .title ("Market" , fontsize = 15 )
29+ avg_price = np .exp (logp .mean (0 ))
30+ l1 = plt .plot (data ["dates" ], avg_price , label = "avg. price in {}" .format (data ['default_currency' ]), color = "C0" )
31+ l2 = plt .plot (data ["dates" ], est [0 ], label = "trend" , color = "C1" )
32+ l3 = plt .fill_between (data ["dates" ], lb [0 ], ub [0 ], alpha = 0.2 , label = "+/- 2 st. dev." , color = "C0" )
33+ plt .ylabel ("avg. price in {}" .format (data ['default_currency' ]), fontsize = 12 )
34+ plt .twinx ()
35+ l4 = plt .bar (data ["dates" ], data ['volume' ].mean (0 ), width = 1 , color = 'g' , alpha = 0.2 , label = 'avg. volume' )
36+ l4 [0 ].set_edgecolor ('r' )
37+ for d in range (1 , t ):
38+ if avg_price [d ] - avg_price [d - 1 ] < 0 :
39+ l4 [d ].set_color ('r' )
40+ plt .ylabel ("avg. volume" , fontsize = 12 )
41+ ll = l1 + l2 + [l3 ] + [l4 ]
42+ labels = [l .get_label () for l in ll ]
43+ plt .legend (ll , labels , loc = "upper left" )
44+ fig_name = 'market_estimation.png'
45+ fig .savefig (fig_name , dpi = fig .dpi )
46+ print ('Market estimation plot has been saved to {}/{}.' .format (os .getcwd (), fig_name ))
47+
48+ def plot_sector_estimates (data : dict , info : dict , est : np .array , std : np .array ):
49+ """
50+ It makes a plot for each sector with prices, trends, uncertainties and volumes.
51+
52+ Parameters
53+ ----------
54+ data: dict
55+ Downloaded data.
56+ info: dict
57+ Model hierarchy information.
58+ est: np.array
59+ Price trend estimate at sector-level.
60+ std: np.array
61+ Standard deviation estimate of price trend at sector-level.
62+ """
63+ print ('\n Plotting sector estimation...' )
64+ num_columns = 3
65+ logp = np .log (data ['price' ])
66+ t = logp .shape [1 ]
67+ lb , ub = compute_uncertainty_bounds (est , std )
68+
69+ NA_sectors = np .where (np .array ([sec [:2 ] for sec in info ['unique_sectors' ]]) == "NA" )[0 ]
70+ num_NA_sectors = len (NA_sectors )
71+
72+ fig = plt .figure (figsize = (20 , max (info ['num_sectors' ] - num_NA_sectors , 5 )))
73+ j = 0
74+ for i in range (info ['num_sectors' ]):
75+ if i not in NA_sectors :
76+ j += 1
77+ plt .subplot (int (np .ceil ((info ['num_sectors' ] - num_NA_sectors ) / num_columns )), num_columns , j )
78+ plt .grid (axis = 'both' )
79+ plt .title (info ['unique_sectors' ][i ], fontsize = 15 )
80+ idx_sectors = np .where (np .array (info ['sectors_id' ]) == i )[0 ]
81+ avg_price = np .exp (logp [idx_sectors ].reshape (- 1 , t ).mean (0 ))
82+ l1 = plt .plot (data ["dates" ], avg_price ,
83+ label = "avg. price in {}" .format (data ['default_currency' ]), color = "C0" )
84+ l2 = plt .plot (data ["dates" ], est [i ], label = "trend" , color = "C1" )
85+ l3 = plt .fill_between (data ["dates" ], lb [i ], ub [i ], alpha = 0.2 , label = "+/- 2 st. dev." ,
86+ color = "C0" )
87+ plt .ylabel ("avg. price in {}" .format (data ['default_currency' ]), fontsize = 12 )
88+ plt .xticks (rotation = 45 )
89+ plt .twinx ()
90+ l4 = plt .bar (data ["dates" ],
91+ data ['volume' ][np .where (np .array (info ['sectors_id' ]) == i )[0 ]].reshape (- 1 , t ).mean (0 ),
92+ width = 1 , color = 'g' , alpha = 0.2 , label = 'avg. volume' )
93+ for d in range (1 , t ):
94+ if avg_price [d ] - avg_price [d - 1 ] < 0 :
95+ l4 [d ].set_color ('r' )
96+ l4 [0 ].set_edgecolor ('r' )
97+ plt .ylabel ("avg. volume" , fontsize = 12 )
98+ ll = l1 + l2 + [l3 ] + [l4 ]
99+ labels = [l .get_label () for l in ll ]
100+ plt .legend (ll , labels , loc = "upper left" )
101+
102+ plt .tight_layout ()
103+ fig_name = 'sector_estimation.png'
104+ fig .savefig (fig_name , dpi = fig .dpi )
105+ print ('Sector estimation plot has been saved to {}/{}.' .format (os .getcwd (), fig_name ))
106+
107+ def plot_industry_estimates (data : dict , info : dict , est : np .array , std : np .array ):
108+ """
109+ It makes a plot for each industry with prices, trends, uncertainties and volumes.
110+
111+ Parameters
112+ ----------
113+ data: dict
114+ Downloaded data.
115+ info: dict
116+ Model hierarchy information.
117+ est: np.array
118+ Price trend estimate at industry-level.
119+ std: np.array
120+ Standard deviation estimate of price trend at industry-level.
121+ """
122+ print ('\n Plotting industry estimation...' )
123+ num_columns = 3
124+ logp = np .log (data ['price' ])
125+ t = logp .shape [1 ]
126+ lb , ub = compute_uncertainty_bounds (est , std )
127+
128+ NA_industries = np .where (np .array ([ind [:2 ] for ind in info ['unique_industries' ]]) == "NA" )[0 ]
129+ num_NA_industries = len (NA_industries )
130+
131+ fig = plt .figure (figsize = (20 , max (info ['num_industries' ] - num_NA_industries , 5 )))
132+ j = 0
133+ for i in range (info ['num_industries' ]):
134+ if i not in NA_industries :
135+ j += 1
136+ plt .subplot (int (np .ceil ((info ['num_industries' ] - num_NA_industries ) / num_columns )), num_columns , j )
137+ plt .grid (axis = 'both' )
138+ plt .title (info ['unique_industries' ][i ], fontsize = 15 )
139+ idx_industries = np .where (np .array (info ['industries_id' ]) == i )[0 ]
140+ plt .title (info ['unique_industries' ][i ], fontsize = 15 )
141+ avg_price = np .exp (logp [idx_industries ].reshape (- 1 , t ).mean (0 ))
142+ l1 = plt .plot (data ["dates" ], avg_price ,
143+ label = "avg. price in {}" .format (data ['default_currency' ]), color = "C0" )
144+ l2 = plt .plot (data ["dates" ], est [i ], label = "trend" , color = "C1" )
145+ l3 = plt .fill_between (data ["dates" ], lb [i ], ub [i ], alpha = 0.2 , label = "+/- 2 st. dev." ,
146+ color = "C0" )
147+ plt .ylabel ("avg. price in {}" .format (data ['default_currency' ]), fontsize = 12 )
148+ plt .xticks (rotation = 45 )
149+ plt .twinx ()
150+ l4 = plt .bar (data ["dates" ],
151+ data ['volume' ][np .where (np .array (info ['industries_id' ]) == i )[0 ]].reshape (- 1 , t ).mean (0 ),
152+ width = 1 , color = 'g' , alpha = 0.2 , label = 'avg. volume' )
153+ for d in range (1 , t ):
154+ if avg_price [d ] - avg_price [d - 1 ] < 0 :
155+ l4 [d ].set_color ('r' )
156+ l4 [0 ].set_edgecolor ('r' )
157+ plt .ylabel ("avg. volume" , fontsize = 12 )
158+ ll = l1 + l2 + [l3 ] + [l4 ]
159+ labels = [l .get_label () for l in ll ]
160+ plt .legend (ll , labels , loc = "upper left" )
161+ plt .tight_layout ()
162+ fig_name = 'industry_estimation.png'
163+ fig .savefig (fig_name , dpi = fig .dpi )
164+ print ('Industry estimation plot has been saved to {}/{}.' .format (os .getcwd (), fig_name ))
165+
166+ def plot_stock_estimates (data : dict , est : np .array , std : np .array , rank_type : str , rank : list , ranked_rates : np .array ):
167+ """
168+ It makes a plot for each stock with prices, trends, uncertainties and volumes.
169+
170+ Parameters
171+ ----------
172+ data: dict
173+ Downloaded data.
174+ est: np.array
175+ Price trend estimate at stock-level.
176+ std: np.array
177+ Standard deviation estimate of price trend at stock-level.
178+ rank_type: str
179+ Type of rank. It can be either `rate` or `growth`.
180+ rank: list
181+ List of integers at stock-level indicating the rank specified in `rank_type`.
182+ ranked_rates: np.array
183+ Array of rates at stock-level ranked according to `rank`.
184+ """
185+ num_stocks , t = data ['price' ].shape
186+
187+ # determine which stocks are along trend to avoid plotting them
188+ if rank_type == "rate" :
189+ to_plot = np .where (np .array (ranked_rates ) != "ALONG TREND" )[0 ]
190+ else :
191+ to_plot = np .where (np .array (ranked_rates ) == "ALONG TREND" )[0 ][:99 ]
192+ dont_plot = [x for x in np .arange (num_stocks ) if x not in to_plot ]
193+ num_to_plot = len (to_plot )
194+ if num_to_plot > 0 :
195+ print ('\n Plotting stock estimation...' )
196+ num_columns = 3
197+
198+ ranked_tickers = np .array (data ['tickers' ])[rank ]
199+ ranked_p = data ['price' ][rank ]
200+ ranked_volume = data ['volume' ][rank ]
201+ ranked_currencies = np .array (data ['currencies' ])[rank ]
202+ ranked_est = est [rank ]
203+ ranked_std = std [rank ]
204+
205+ ranked_lb , ranked_ub = compute_uncertainty_bounds (ranked_est , ranked_std )
206+
207+ j = 0
208+ fig = plt .figure (figsize = (20 , max (num_to_plot , 5 )))
209+ for i in range (num_stocks ):
210+ if i not in dont_plot :
211+ j += 1
212+ plt .subplot (int (np .ceil (num_to_plot / num_columns )), num_columns , j )
213+ plt .grid (axis = 'both' )
214+ plt .title (ranked_tickers [i ], fontsize = 15 )
215+ l1 = plt .plot (data ["dates" ], ranked_p [i ], label = "price in {}" .format (ranked_currencies [i ]))
216+ l2 = plt .plot (data ["dates" ], ranked_est [i ], label = "trend" )
217+ l3 = plt .fill_between (data ["dates" ], ranked_lb [i ], ranked_ub [i ], alpha = 0.2 ,
218+ label = "+/- 2 st. dev." )
219+ plt .yticks (fontsize = 12 )
220+ plt .xticks (rotation = 45 )
221+ plt .ylabel ("price in {}" .format (ranked_currencies [i ]), fontsize = 12 )
222+ plt .twinx ()
223+ l4 = plt .bar (data ["dates" ], ranked_volume [i ], width = 1 , color = 'g' , alpha = 0.2 , label = 'volume' )
224+ for d in range (1 , t ):
225+ if ranked_p [i , d ] - ranked_p [i , d - 1 ] < 0 :
226+ l4 [d ].set_color ('r' )
227+ l4 [0 ].set_edgecolor ('r' )
228+ plt .ylabel ("volume" , fontsize = 12 )
229+ ll = l1 + l2 + [l3 ] + [l4 ]
230+ labels = [l .get_label () for l in ll ]
231+ plt .legend (ll , labels , loc = "upper left" )
232+ plt .tight_layout ()
233+ fig_name = 'stock_estimation.png'
234+ fig .savefig (fig_name , dpi = fig .dpi )
235+ print ('Stock estimation plot has been saved to {}/{}.' .format (os .getcwd (), fig_name ))
236+
237+ elif os .path .exists ('stock_estimation.png' ):
238+ os .remove ('stock_estimation.png' )
0 commit comments