Skip to content

Commit 4b9c8c0

Browse files
committed
Allow addplot, volume, and edgecolor for renko
1 parent b159292 commit 4b9c8c0

File tree

2 files changed

+66
-34
lines changed

2 files changed

+66
-34
lines changed

src/mplfinance/_utils.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def roundTime(dt=None, roundTo=60):
8282
rounding = (seconds+roundTo/2) // roundTo * roundTo
8383
return dt + datetime.timedelta(0,rounding-seconds,-dt.microsecond)
8484

85-
def calculate_atr(atr_length, highs, lows, closes):
85+
def _calculate_atr(atr_length, highs, lows, closes):
8686
"""Calculate the average true range
8787
atr_length : time period to calculate over
8888
all_highs : list of highs
@@ -98,6 +98,30 @@ def calculate_atr(atr_length, highs, lows, closes):
9898
atr += tr
9999
return atr/atr_length
100100

101+
def renko_reformat_ydata(ydata, dates, old_dates):
102+
"""Reformats ydata to work on renko charts, can lead to unexpected
103+
outputs for the user as the xaxis does not scale evenly with dates.
104+
Missing dates ydata is averaged into the next date and dates that appear
105+
more than once have the same ydata
106+
ydata : y data likely coming from addplot
107+
dates : x-axis dates for the renko chart
108+
old_dates : original dates in the data set
109+
"""
110+
new_ydata = [] # stores new ydata
111+
prev_data = 0
112+
skipped_dates = 0
113+
count_skip = 0
114+
for i in range(len(ydata)):
115+
if old_dates[i] not in dates:
116+
prev_data += ydata[i]
117+
skipped_dates += 1
118+
else:
119+
dup_dates = dates.count(old_dates[i])
120+
new_ydata.extend([(ydata[i]+prev_data)/(skipped_dates+1)]*dup_dates)
121+
skipped_dates = 0
122+
prev_data = 0
123+
return new_ydata
124+
101125
def _updown_colors(upcolor,downcolor,opens,closes,use_prev_close=False):
102126
if upcolor == downcolor:
103127
return upcolor
@@ -279,7 +303,7 @@ def _construct_candlestick_collections(dates, opens, highs, lows, closes, market
279303

280304
return rangeCollection, barCollection
281305

282-
def _construct_renko_collections(dates, highs, lows, renko_params, closes, marketcolors=None):
306+
def _construct_renko_collections(dates, highs, lows, volumes, renko_params, closes, marketcolors=None):
283307
"""Represent the price change with bricks
284308
285309
Parameters
@@ -314,28 +338,36 @@ def _construct_renko_collections(dates, highs, lows, renko_params, closes, marke
314338
raise ValueError("Specified atr_length is larger than the length of the dataset: " + str(len(closes)))
315339

316340
if brick_size == 'atr':
317-
brick_size = calculate_atr(atr_length, highs, lows, closes)
318-
print(brick_size)
341+
brick_size = _calculate_atr(atr_length, highs, lows, closes)
319342

320343
alpha = marketcolors['alpha']
321344

322-
uc = mcolors.to_rgba(marketcolors['candle'][ 'up' ], alpha)
323-
dc = mcolors.to_rgba(marketcolors['candle']['down'], alpha)
345+
uc = mcolors.to_rgba(marketcolors['candle'][ 'up' ], 1.0)
346+
dc = mcolors.to_rgba(marketcolors['candle']['down'], 1.0)
347+
euc = mcolors.to_rgba(marketcolors['edge'][ 'up' ], 1.0)
348+
edc = mcolors.to_rgba(marketcolors['edge']['down'], 1.0)
324349

325350
cdiff = [(closes[i+1] - closes[i])/brick_size for i in range(len(closes)-1)] # fill cdiff with close price change
326351

327352
bricks = [] # holds bricks, 1 for down bricks, -1 for up bricks
328353
new_dates = [] # holds the dates corresponding with the index
354+
new_volumes = [] # holds the volumes corresponding with the index. If more than one index for the same day then they all have the same volume.
329355

330356
prev_num = 0
331357
start_price = closes[0]
358+
359+
volume_cache = 0 # holds the volumes for the dates that were skipped
332360

333361

334362
for i in range(len(cdiff)):
335363
num_bricks = abs(int(round(cdiff[i], 0)))
336364

337365
if num_bricks != 0:
338366
new_dates.extend([dates[i]]*num_bricks)
367+
new_volumes.extend([volumes[i] + volume_cache]*num_bricks)
368+
volume_cache = 0
369+
else:
370+
volume_cache += volumes[i]
339371

340372
if cdiff[i] > 0:
341373
bricks.extend([1]*num_bricks)
@@ -344,11 +376,14 @@ def _construct_renko_collections(dates, highs, lows, renko_params, closes, marke
344376

345377
verts = []
346378
colors = []
379+
edge_colors = []
347380
for index, number in enumerate(bricks):
348381
if number == 1: # up brick
349382
colors.append(uc)
383+
edge_colors.append(euc)
350384
else: # down brick
351385
colors.append(dc)
386+
edge_colors.append(edc)
352387

353388
prev_num += number
354389
x, y = index, start_price + (prev_num * brick_size)
@@ -365,10 +400,11 @@ def _construct_renko_collections(dates, highs, lows, renko_params, closes, marke
365400
rectCollection = PolyCollection(verts,
366401
facecolors=colors,
367402
antialiaseds=useAA,
403+
edgecolors=edge_colors,
368404
linewidths=lw
369405
)
370406

371-
return (rectCollection, ), new_dates
407+
return (rectCollection, ), new_dates, new_volumes
372408

373409
from matplotlib.ticker import Formatter
374410
class IntegerIndexDateTimeFormatter(Formatter):

src/mplfinance/plotting.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mplfinance._utils import _construct_candlestick_collections
1717
from mplfinance._utils import _construct_renko_collections
1818

19+
from mplfinance._utils import renko_reformat_ydata
1920
from mplfinance._utils import _updown_colors
2021
from mplfinance._utils import IntegerIndexDateTimeFormatter
2122

@@ -222,12 +223,10 @@ def plot( data, **kwargs ):
222223
for apdict in addplot:
223224
if apdict['panel'] == 'lower':
224225
need_lower_panel = True
225-
break
226-
227-
ptype = config['type']
226+
break
228227

229228
# fig.add_axes( [left, bottom, width, height] ) ... numbers are fraction of fig
230-
if need_lower_panel or config['volume'] and ptype is not 'renko':
229+
if need_lower_panel or config['volume']:
231230
ax1 = fig.add_axes( [0.15, 0.38, 0.70, 0.50] )
232231
ax2 = fig.add_axes( [0.15, 0.18, 0.70, 0.20], sharex=ax1 )
233232
plt.xticks(rotation=45) # must do this after creation of axis, and
@@ -258,7 +257,7 @@ def plot( data, **kwargs ):
258257
else:
259258
fmtstring = '%b %d'
260259

261-
260+
ptype = config['type']
262261

263262
if ptype is not 'renko':
264263
if config['show_nontrading']:
@@ -270,10 +269,6 @@ def plot( data, **kwargs ):
270269

271270
ax1.xaxis.set_major_formatter(formatter)
272271

273-
274-
275-
renko_params = config['renko_params']
276-
277272
collections = None
278273
if ptype == 'candle' or ptype == 'candlestick':
279274
collections = _construct_candlestick_collections(xdates, opens, highs, lows, closes,
@@ -282,8 +277,8 @@ def plot( data, **kwargs ):
282277
collections = _construct_ohlc_collections(xdates, opens, highs, lows, closes,
283278
marketcolors=style['marketcolors'] )
284279
elif ptype == 'renko':
285-
renko_params = _process_kwargs(kwargs['renko_params'] if 'renko_params' in kwargs else dict(), _valid_renko_kwargs())
286-
collections, new_dates = _construct_renko_collections(dates, highs, lows, renko_params, closes,
280+
renko_params = _process_kwargs(config['renko_params'], _valid_renko_kwargs())
281+
collections, new_dates, volumes = _construct_renko_collections(dates, highs, lows, volumes, renko_params, closes,
287282
marketcolors=style['marketcolors'] )
288283

289284
formatter = IntegerIndexDateTimeFormatter(new_dates, fmtstring)
@@ -312,25 +307,23 @@ def plot( data, **kwargs ):
312307
mavc = None
313308

314309
for mav in mavgs:
315-
mavprices = data['Close'].rolling(mav).mean().values
310+
mavprices = data['Close'].rolling(mav).mean().values
311+
if ptype == 'renko':
312+
mavprices = renko_reformat_ydata(mavprices, new_dates, dates)
316313
if mavc:
317314
ax1.plot(xdates, mavprices, color=next(mavc))
318315
else:
319316
ax1.plot(xdates, mavprices)
320317

321-
322-
if ptype == 'renko':
323-
ax1.autoscale()
324-
else:
325-
avg_dist_between_points = (xdates[-1] - xdates[0]) / float(len(xdates))
326-
minx = xdates[0] - avg_dist_between_points
327-
maxx = xdates[-1] + avg_dist_between_points
328-
miny = min([low for low in lows if low != -1])
329-
maxy = max([high for high in highs if high != -1])
330-
corners = (minx, miny), (maxx, maxy)
331-
ax1.update_datalim(corners)
332-
333-
if config['volume'] and ptype is not 'renko':
318+
avg_dist_between_points = (xdates[-1] - xdates[0]) / float(len(xdates))
319+
minx = xdates[0] - avg_dist_between_points
320+
maxx = xdates[-1] + avg_dist_between_points
321+
miny = min([low for low in lows if low != -1])
322+
maxy = max([high for high in highs if high != -1])
323+
corners = (minx, miny), (maxx, maxy)
324+
ax1.update_datalim(corners)
325+
326+
if config['volume']:
334327
vup,vdown = style['marketcolors']['volume'].values()
335328
#-- print('vup,vdown=',vup,vdown)
336329
vcolors = _updown_colors(vup, vdown, opens, closes, use_prev_close=style['marketcolors']['vcdopcod'])
@@ -415,6 +408,9 @@ def plot( data, **kwargs ):
415408
if ax == ax4:
416409
used_ax4 = True
417410

411+
if ptype == 'renko':
412+
ydata = renko_reformat_ydata(ydata, new_dates, dates)
413+
418414
if apdict['scatter']:
419415
size = apdict['markersize']
420416
mark = apdict['marker']
@@ -449,7 +445,7 @@ def plot( data, **kwargs ):
449445
ax4.yaxis.set_label_position('right')
450446
ax4.yaxis.tick_right()
451447

452-
if need_lower_panel or config['volume'] and ptype is not 'renko':
448+
if need_lower_panel or config['volume']:
453449
ax1.spines['bottom'].set_linewidth(0.25)
454450
ax2.spines['top' ].set_linewidth(0.25)
455451
plt.setp(ax1.get_xticklabels(), visible=False)
@@ -476,7 +472,7 @@ def plot( data, **kwargs ):
476472

477473
ax1.set_ylabel(config['ylabel'])
478474

479-
if config['volume'] and ptype is not 'renko':
475+
if config['volume']:
480476
ax2.figure.canvas.draw() # This is needed to calculate offset
481477
offset = ax2.yaxis.get_major_formatter().get_offset()
482478
ax2.yaxis.offsetText.set_visible(False)

0 commit comments

Comments
 (0)