Skip to content

Commit a10a5e2

Browse files
Merge pull request #65 from DanielGoldfarb/master
Return Figure and Axes (plus some minor changes)
2 parents 493055a + 431b994 commit a10a5e2

File tree

6 files changed

+23
-13
lines changed

6 files changed

+23
-13
lines changed

.github/ISSUE_TEMPLATE/feature_request.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
---
22
name: Feature request
33
about: Suggest an idea for this project
4-
title: 'Feature Reuest:'
4+
title: 'Feature Request:'
55
labels: 'enhancement'
66
assignees: ''
77

src/mplfinance/plotting.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ def _valid_plot_kwargs():
125125
'block' : { 'Default' : True,
126126
'Validator' : lambda value: isinstance(value,bool) },
127127

128+
'returnfig' : { 'Default' : False,
129+
'Validator' : lambda value: isinstance(value,bool) },
130+
128131
}
129132

130133
_validate_vkwargs_dict(vkwargs)
@@ -476,22 +479,29 @@ def plot( data, **kwargs ):
476479
if not used_ax4 and ax4 is not None:
477480
ax4.get_yaxis().set_visible(False)
478481

482+
if config['returnfig']:
483+
axlist = [ax1, ax3]
484+
if ax2: axlist.append(ax2)
485+
if ax4: axlist.append(ax4)
486+
479487
if config['savefig'] is not None:
480488
save = config['savefig']
481489
if isinstance(save,dict):
482490
plt.savefig(**save)
483491
else:
484492
plt.savefig(save)
485-
else:
493+
elif not config['returnfig']:
486494
# https://stackoverflow.com/a/13361748/1639359 suggests plt.show(block=False)
487495
plt.show(block=config['block'])
496+
497+
if config['returnfig']:
498+
return (fig, axlist)
488499

489500
# rcp = copy.deepcopy(plt.rcParams)
490501
# rcpdf = rcParams_to_df(rcp)
491502
# print('type(rcpdf)=',type(rcpdf))
492503
# print('rcpdfhead(3)=',rcpdf.head(3))
493504
# return # rcpdf
494-
495505

496506

497507
def _valid_addplot_kwargs():

tests/original_flavor/test_date_demo1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_date_demo1():
3131
infile = os.path.join('examples','data','yahoofinance-INTC-19950101-20040412.csv')
3232
quotes = pd.read_csv(infile,index_col=0,parse_dates=True,infer_datetime_format=True)
3333

34-
dates = quotes.index
34+
dates = quotes.index.values
3535
opens = quotes['Open']
3636

3737
fig, ax = plt.subplots()

tests/original_flavor/test_date_demo2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_date_demo2():
3535
# select desired range of dates
3636
quotes = quotes[(quotes.index >= date1) & (quotes.index <= date2)]
3737

38-
dates = quotes.index
38+
dates = quotes.index.values
3939
opens = quotes['Open']
4040

4141

tests/original_flavor/test_finance_work2.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,12 @@ def moving_average_convergence(x, nslow=26, nfast=12):
109109
rsi = relative_strength(prices)
110110
fillcolor = 'darkgoldenrod'
111111

112-
ax1.plot(r.index, rsi, color=fillcolor)
112+
ax1.plot(r.index.values, rsi, color=fillcolor)
113113
ax1.axhline(70, color=fillcolor)
114114
ax1.axhline(30, color=fillcolor)
115-
ax1.fill_between(r.index, rsi, 70, where=(rsi >= 70),
115+
ax1.fill_between(r.index.values, rsi, 70, where=(rsi >= 70),
116116
facecolor=fillcolor, edgecolor=fillcolor)
117-
ax1.fill_between(r.index, rsi, 30, where=(rsi <= 30),
117+
ax1.fill_between(r.index.values, rsi, 30, where=(rsi <= 30),
118118
facecolor=fillcolor, edgecolor=fillcolor)
119119
ax1.text(0.6, 0.9, '>70 = overbought', va='top',
120120
transform=ax1.transAxes, fontsize=textsize)
@@ -140,8 +140,8 @@ def moving_average_convergence(x, nslow=26, nfast=12):
140140
ma20 = moving_average(prices, 20, type='simple')
141141
ma200 = moving_average(prices, 200, type='simple')
142142

143-
linema20, = ax2.plot(r.index, ma20, color='blue', lw=2, label='MA (20)')
144-
linema200, = ax2.plot(r.index, ma200, color='red', lw=2, label='MA (200)')
143+
linema20, = ax2.plot(r.index.values, ma20, color='blue', lw=2, label='MA (20)')
144+
linema200, = ax2.plot(r.index.values, ma200, color='red', lw=2, label='MA (200)')
145145

146146
last = r.tail(1)
147147
s = '%s O:%1.2f H:%1.2f L:%1.2f C:%1.2f, V:%1.1fM Chg:%+1.2f' % (
@@ -173,8 +173,8 @@ def moving_average_convergence(x, nslow=26, nfast=12):
173173
emaslow, emafast, macd = moving_average_convergence(
174174
prices, nslow=nslow, nfast=nfast)
175175
ema9 = moving_average(macd, nema, type='exponential')
176-
ax3.plot(r.index, macd, color='black', lw=2)
177-
ax3.plot(r.index, ema9, color='blue', lw=1)
176+
ax3.plot(r.index.values, macd, color='black', lw=2)
177+
ax3.plot(r.index.values, ema9, color='blue', lw=1)
178178
ax3.fill_between(r.index, macd - ema9, 0, alpha=0.5,
179179
facecolor=fillcolor, edgecolor=fillcolor)
180180

tests/original_flavor/test_longshort.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_longshort():
5454

5555
# plot the return
5656
fig, ax = plt.subplots()
57-
ax.plot(r.index, tr)
57+
ax.plot(r.index.values, tr)
5858
ax.set_title('total return: long APPL, short GOOG')
5959
ax.grid()
6060
fig.autofmt_xdate()

0 commit comments

Comments
 (0)