Skip to content

Commit 3b73173

Browse files
convert original examples to pytests
1 parent 019d8bb commit 3b73173

File tree

7 files changed

+492
-20
lines changed

7 files changed

+492
-20
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""
2+
Show how to make date plots in matplotlib using date tick locators and
3+
formatters. See major_minor_demo1.py for more information on
4+
controlling major and minor ticks
5+
6+
All matplotlib date plotting is done by converting date instances into
7+
days since the 0001-01-01 UTC. The conversion, tick locating and
8+
formatting is done behind the scenes so this is most transparent to
9+
you. The dates module provides several converter functions date2num
10+
and num2date
11+
12+
This example requires an active internet connection since it uses
13+
yahoo finance to get the data for plotting
14+
"""
15+
16+
import matplotlib.pyplot as plt
17+
import pandas as pd
18+
from pandas.plotting import register_matplotlib_converters
19+
register_matplotlib_converters()
20+
from matplotlib.dates import DateFormatter, MonthLocator, YearLocator
21+
import os.path
22+
import io
23+
24+
def test_date_demo1():
25+
26+
years = YearLocator() # every year
27+
months = MonthLocator() # every month
28+
yearsFmt = DateFormatter('%Y')
29+
30+
# make file paths OS independent
31+
infile = os.path.join('examples','data','yahoofinance-INTC-19950101-20040412.csv')
32+
quotes = pd.read_csv(infile,index_col=0,parse_dates=True,infer_datetime_format=True)
33+
34+
dates = quotes.index
35+
opens = quotes['Open']
36+
37+
fig, ax = plt.subplots()
38+
ax.plot_date(dates, opens, '-')
39+
40+
# format the ticks
41+
ax.xaxis.set_major_locator(years)
42+
ax.xaxis.set_major_formatter(yearsFmt)
43+
ax.xaxis.set_minor_locator(months)
44+
ax.autoscale_view()
45+
46+
47+
# format the coords message box
48+
def price(x):
49+
return '$%1.2f' % x
50+
51+
52+
ax.fmt_xdata = DateFormatter('%Y-%m-%d')
53+
ax.fmt_ydata = price
54+
ax.grid(True)
55+
56+
fig.autofmt_xdate()
57+
buf = io.BytesIO()
58+
plt.savefig(buf)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
2+
Show how to make date plots in matplotlib using date tick locators and
3+
formatters. See major_minor_demo1.py for more information on
4+
controlling major and minor ticks
5+
"""
6+
7+
import matplotlib.pyplot as plt
8+
import pandas as pd
9+
from pandas.plotting import register_matplotlib_converters
10+
register_matplotlib_converters()
11+
from matplotlib.dates import (MONDAY, DateFormatter, MonthLocator,
12+
WeekdayLocator)
13+
import os.path
14+
import io
15+
16+
def test_date_demo2():
17+
18+
date1 = "2002-1-5"
19+
date2 = "2003-12-1"
20+
21+
# every monday
22+
mondays = WeekdayLocator(MONDAY)
23+
24+
# every 3rd month
25+
months = MonthLocator(range(1, 13), bymonthday=1, interval=3)
26+
monthsFmt = DateFormatter("%b '%y")
27+
28+
29+
infile = os.path.join('examples','data','yahoofinance-INTC-19950101-20040412.csv')
30+
quotes = pd.read_csv(infile,
31+
index_col=0,
32+
parse_dates=True,
33+
infer_datetime_format=True)
34+
35+
# select desired range of dates
36+
quotes = quotes[(quotes.index >= date1) & (quotes.index <= date2)]
37+
38+
dates = quotes.index
39+
opens = quotes['Open']
40+
41+
42+
fig, ax = plt.subplots()
43+
ax.plot_date(dates, opens, '-')
44+
ax.xaxis.set_major_locator(months)
45+
ax.xaxis.set_major_formatter(monthsFmt)
46+
ax.xaxis.set_minor_locator(mondays)
47+
ax.autoscale_view()
48+
# ax.xaxis.grid(False, 'major')
49+
# ax.xaxis.grid(True, 'minor')
50+
ax.grid(True)
51+
52+
fig.autofmt_xdate()
53+
54+
buf = io.BytesIO()
55+
plt.savefig(buf)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import datetime
2+
3+
import matplotlib.dates as mdates
4+
import matplotlib.pyplot as plt
5+
import pandas as pd
6+
from pandas.plotting import register_matplotlib_converters
7+
register_matplotlib_converters()
8+
from matplotlib.dates import MONDAY, DateFormatter, DayLocator, WeekdayLocator
9+
import os.path
10+
import io
11+
12+
from mplfinance.original_flavor import candlestick_ohlc
13+
14+
def test_finance_demo():
15+
16+
date1 = "2004-2-1"
17+
date2 = "2004-4-12"
18+
19+
20+
mondays = WeekdayLocator(MONDAY) # major ticks on the mondays
21+
alldays = DayLocator() # minor ticks on the days
22+
weekFormatter = DateFormatter('%b %d') # e.g., Jan 12
23+
dayFormatter = DateFormatter('%d') # e.g., 12
24+
25+
infile = os.path.join('examples','data','yahoofinance-INTC-19950101-20040412.csv')
26+
quotes = pd.read_csv(infile,
27+
index_col=0,
28+
parse_dates=True,
29+
infer_datetime_format=True)
30+
31+
# select desired range of dates
32+
quotes = quotes[(quotes.index >= date1) & (quotes.index <= date2)]
33+
34+
fig, ax = plt.subplots()
35+
fig.subplots_adjust(bottom=0.2)
36+
ax.xaxis.set_major_locator(mondays)
37+
ax.xaxis.set_minor_locator(alldays)
38+
ax.xaxis.set_major_formatter(weekFormatter)
39+
# ax.xaxis.set_minor_formatter(dayFormatter)
40+
41+
# plot_day_summary(ax, quotes, ticksize=3)
42+
candlestick_ohlc(ax, zip(mdates.date2num(quotes.index.to_pydatetime()),
43+
quotes['Open'], quotes['High'],
44+
quotes['Low'], quotes['Close']),
45+
width=0.6)
46+
47+
ax.xaxis_date()
48+
ax.autoscale_view()
49+
plt.setp(plt.gca().get_xticklabels(), rotation=45, horizontalalignment='right')
50+
51+
buf = io.BytesIO()
52+
plt.savefig(buf)
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
import matplotlib.dates as mdates
2+
import matplotlib.font_manager as font_manager
3+
import matplotlib.pyplot as plt
4+
import matplotlib.ticker as mticker
5+
import numpy as np
6+
import pandas as pd
7+
from pandas.plotting import register_matplotlib_converters
8+
register_matplotlib_converters()
9+
import os.path
10+
import io
11+
12+
def test_finance_work2():
13+
14+
ticker = 'SPY'
15+
infile = os.path.join('examples','data','yahoofinance-SPY-20080101-20180101.csv')
16+
r = pd.read_csv(infile,
17+
index_col=0,
18+
parse_dates=True,
19+
infer_datetime_format=True)
20+
21+
22+
def moving_average(x, n, type='simple'):
23+
"""
24+
compute an n period moving average.
25+
26+
type is 'simple' | 'exponential'
27+
28+
"""
29+
x = np.asarray(x)
30+
if type == 'simple':
31+
weights = np.ones(n)
32+
else:
33+
weights = np.exp(np.linspace(-1., 0., n))
34+
35+
weights /= weights.sum()
36+
37+
a = np.convolve(x, weights, mode='full')[:len(x)]
38+
a[:n] = a[n]
39+
return a
40+
41+
42+
def relative_strength(prices, n=14):
43+
"""
44+
compute the n period relative strength indicator
45+
http://stockcharts.com/school/doku.php?id=chart_school:glossary_r#relativestrengthindex
46+
http://www.investopedia.com/terms/r/rsi.asp
47+
"""
48+
49+
deltas = np.diff(prices)
50+
seed = deltas[:n + 1]
51+
up = seed[seed >= 0].sum() / n
52+
down = -seed[seed < 0].sum() / n
53+
rs = up / down
54+
rsi = np.zeros_like(prices)
55+
rsi[:n] = 100. - 100. / (1. + rs)
56+
57+
for i in range(n, len(prices)):
58+
delta = deltas[i - 1] # cause the diff is 1 shorter
59+
60+
if delta > 0:
61+
upval = delta
62+
downval = 0.
63+
else:
64+
upval = 0.
65+
downval = -delta
66+
67+
up = (up * (n - 1) + upval) / n
68+
down = (down * (n - 1) + downval) / n
69+
70+
rs = up / down
71+
rsi[i] = 100. - 100. / (1. + rs)
72+
73+
return rsi
74+
75+
76+
def moving_average_convergence(x, nslow=26, nfast=12):
77+
"""
78+
compute the MACD (Moving Average Convergence/Divergence) using a fast and
79+
slow exponential moving avg
80+
81+
return value is emaslow, emafast, macd which are len(x) arrays
82+
"""
83+
emaslow = moving_average(x, nslow, type='exponential')
84+
emafast = moving_average(x, nfast, type='exponential')
85+
return emaslow, emafast, emafast - emaslow
86+
87+
88+
plt.rc('axes', grid=True)
89+
plt.rc('grid', color='0.75', linestyle='-', linewidth=0.5)
90+
91+
textsize = 9
92+
left, width = 0.1, 0.8
93+
rect1 = [left, 0.7, width, 0.2]
94+
rect2 = [left, 0.3, width, 0.4]
95+
rect3 = [left, 0.1, width, 0.2]
96+
97+
98+
fig = plt.figure(facecolor='white')
99+
axescolor = '#f6f6f6' # the axes background color
100+
101+
ax1 = fig.add_axes(rect1, facecolor=axescolor) # left, bottom, width, height
102+
ax2 = fig.add_axes(rect2, facecolor=axescolor, sharex=ax1)
103+
ax2t = ax2.twinx()
104+
ax3 = fig.add_axes(rect3, facecolor=axescolor, sharex=ax1)
105+
106+
107+
# plot the relative strength indicator
108+
prices = r["Adj Close"]
109+
rsi = relative_strength(prices)
110+
fillcolor = 'darkgoldenrod'
111+
112+
ax1.plot(r.index, rsi, color=fillcolor)
113+
ax1.axhline(70, color=fillcolor)
114+
ax1.axhline(30, color=fillcolor)
115+
ax1.fill_between(r.index, rsi, 70, where=(rsi >= 70),
116+
facecolor=fillcolor, edgecolor=fillcolor)
117+
ax1.fill_between(r.index, rsi, 30, where=(rsi <= 30),
118+
facecolor=fillcolor, edgecolor=fillcolor)
119+
ax1.text(0.6, 0.9, '>70 = overbought', va='top',
120+
transform=ax1.transAxes, fontsize=textsize)
121+
ax1.text(0.6, 0.1, '<30 = oversold',
122+
transform=ax1.transAxes, fontsize=textsize)
123+
ax1.set_ylim(0, 100)
124+
ax1.set_yticks([30, 70])
125+
ax1.text(0.025, 0.95, 'RSI (14)', va='top',
126+
transform=ax1.transAxes, fontsize=textsize)
127+
ax1.set_title('%s daily' % ticker)
128+
129+
# plot the price and volume data
130+
dx = r["Adj Close"] - r.Close
131+
low = r.Low + dx
132+
high = r.High + dx
133+
134+
deltas = np.zeros_like(prices)
135+
deltas[1:] = np.diff(prices)
136+
up = deltas > 0
137+
ax2.vlines(r.index[up], low[up], high[up], color='black', label='_nolegend_')
138+
ax2.vlines(r.index[~up], low[~up], high[~up],
139+
color='black', label='_nolegend_')
140+
ma20 = moving_average(prices, 20, type='simple')
141+
ma200 = moving_average(prices, 200, type='simple')
142+
143+
linema20, = ax2.plot(r.index, ma20, color='blue', lw=2, label='MA (20)')
144+
linema200, = ax2.plot(r.index, ma200, color='red', lw=2, label='MA (200)')
145+
146+
last = r.tail(1)
147+
s = '%s O:%1.2f H:%1.2f L:%1.2f C:%1.2f, V:%1.1fM Chg:%+1.2f' % (
148+
last.index.strftime('%Y.%m.%d')[0],
149+
last.Open, last.High,
150+
last.Low, last.Close,
151+
last.Volume * 1e-6,
152+
last.Close - last.Open)
153+
t4 = ax2.text(0.3, 0.9, s, transform=ax2.transAxes, fontsize=textsize)
154+
155+
props = font_manager.FontProperties(size=10)
156+
leg = ax2.legend(loc='center left', shadow=True, fancybox=True, prop=props)
157+
leg.get_frame().set_alpha(0.5)
158+
159+
160+
volume = (r.Close * r.Volume) / 1e6 # dollar volume in millions
161+
vmax = volume.max()
162+
poly = ax2t.fill_between(r.index, volume, 0, label='Volume',
163+
facecolor=fillcolor, edgecolor=fillcolor)
164+
ax2t.set_ylim(0, 5 * vmax)
165+
ax2t.set_yticks([])
166+
167+
168+
# compute the MACD indicator
169+
fillcolor = 'darkslategrey'
170+
nslow = 26
171+
nfast = 12
172+
nema = 9
173+
emaslow, emafast, macd = moving_average_convergence(
174+
prices, nslow=nslow, nfast=nfast)
175+
ema9 = moving_average(macd, nema, type='exponential')
176+
ax3.plot(r.index, macd, color='black', lw=2)
177+
ax3.plot(r.index, ema9, color='blue', lw=1)
178+
ax3.fill_between(r.index, macd - ema9, 0, alpha=0.5,
179+
facecolor=fillcolor, edgecolor=fillcolor)
180+
181+
182+
ax3.text(0.025, 0.95, 'MACD (%d, %d, %d)' % (nfast, nslow, nema), va='top',
183+
transform=ax3.transAxes, fontsize=textsize)
184+
185+
# ax3.set_yticks([])
186+
# turn off upper axis tick labels, rotate the lower ones, etc
187+
for ax in ax1, ax2, ax2t, ax3:
188+
if ax != ax3:
189+
for label in ax.get_xticklabels():
190+
label.set_visible(False)
191+
else:
192+
for label in ax.get_xticklabels():
193+
label.set_rotation(30)
194+
label.set_horizontalalignment('right')
195+
196+
ax.fmt_xdata = mdates.DateFormatter('%Y-%m-%d')
197+
198+
199+
class MyLocator(mticker.MaxNLocator):
200+
def __init__(self, *args, **kwargs):
201+
mticker.MaxNLocator.__init__(self, *args, **kwargs)
202+
203+
def __call__(self, *args, **kwargs):
204+
return mticker.MaxNLocator.__call__(self, *args, **kwargs)
205+
206+
# at most 5 ticks, pruning the upper and lower so they don't overlap
207+
# with other ticks
208+
# ax2.yaxis.set_major_locator(mticker.MaxNLocator(5, prune='both'))
209+
# ax3.yaxis.set_major_locator(mticker.MaxNLocator(5, prune='both'))
210+
211+
212+
ax2.yaxis.set_major_locator(MyLocator(5, prune='both'))
213+
ax3.yaxis.set_major_locator(MyLocator(5, prune='both'))
214+
215+
buf = io.BytesIO()
216+
plt.savefig(buf)

0 commit comments

Comments
 (0)