Skip to content

Commit f16a1b1

Browse files
addplot cleanup
1 parent d601b57 commit f16a1b1

File tree

1 file changed

+113
-106
lines changed

1 file changed

+113
-106
lines changed

src/mplfinance/plotting.py

Lines changed: 113 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -468,44 +468,9 @@ def plot( data, **kwargs ):
468468
if apdict['y_on_right'] is not None:
469469
panels.at[panid,'y_on_right'] = apdict['y_on_right']
470470

471-
#--------------------------------------------------------------#
472-
# Note: _auto_secondary_y() sets the 'magnitude' column in the
473-
# `panels` dataframe, which is needed for automatically
474-
# determining if secondary_y is needed. Therefore we call
475-
# _auto_secondary_y() for *all* addplots, even those that
476-
# are set to True or False (not 'auto') for secondary_y
477-
# because their magnitudes may be needed if *any* apdicts
478-
# contain secondary_y='auto'.
479-
# In theory we could first loop through all apdicts to see
480-
# if any have secondary_y='auto', but since that is the
481-
# default value, we will just assume we have at least one.
482-
483-
apdata = apdict['data']
484471
aptype = apdict['type']
485-
486472
if aptype == 'ohlc' or aptype == 'candle':
487-
#import pdb; pdb.set_trace()
488-
if not isinstance(apdata,pd.DataFrame):
489-
raise TypeError('addplot type "'+aptype+'" MUST be accompanied by addplot data of type `pd.DataFrame`')
490-
d,o,h,l,c,v = _check_and_prepare_data(apdata,config)
491-
collections = _construct_mpf_collections(aptype,d,xdates,o,h,l,c,v,config,style)
492-
lo = math.log(max(math.fabs(np.nanmin(l)),1e-7),10) - 0.5
493-
hi = math.log(max(math.fabs(np.nanmax(h)),1e-7),10) + 0.5
494-
secondary_y = _auto_secondary_y( panels, panid, lo, hi )
495-
if 'auto' != apdict['secondary_y']:
496-
secondary_y = apdict['secondary_y']
497-
if secondary_y:
498-
ax = panels.at[panid,'axes'][1]
499-
panels.at[panid,'used2nd'] = True
500-
else:
501-
ax = panels.at[panid,'axes'][0]
502-
for coll in collections:
503-
ax.add_collection(coll)
504-
if apdict['mav'] is not None:
505-
apmavprices = _plot_mav(ax,config,xdates,c,apdict['mav'])
506-
#datalim = (minx, min(l)), (maxx, max(h))
507-
#ax.update_datalim(datalim)
508-
ax.autoscale_view()
473+
ax = _addplot_collections(panid,panels,apdict,xdates,config)
509474
if (apdict['ylabel'] is not None):
510475
ax.set_ylabel(apdict['ylabel'])
511476
if apdict['ylim'] is not None:
@@ -517,77 +482,30 @@ def plot( data, **kwargs ):
517482
#else:
518483
# corners = (minx, miny), (maxx, maxy)
519484
# ax.update_datalim(corners)
520-
continue
521-
522-
if isinstance(apdata,list) and not isinstance(apdata[0],(float,int)):
523-
raise TypeError('apdata is list but NOT of float or int')
524-
if isinstance(apdata,pd.DataFrame):
525-
havedf = True
526-
else:
527-
havedf = False # must be a single series or array
528-
apdata = [apdata,] # make it iterable
529-
530-
for column in apdata:
531-
if havedf:
532-
ydata = apdata.loc[:,column]
485+
else:
486+
apdata = apdict['data']
487+
if isinstance(apdata,list) and not isinstance(apdata[0],(float,int)):
488+
raise TypeError('apdata is list but NOT of float or int')
489+
if isinstance(apdata,pd.DataFrame):
490+
havedf = True
533491
else:
534-
ydata = column
535-
secondary_y = False
536-
if apdict['secondary_y'] == 'auto':
537-
yd = [y for y in ydata if not math.isnan(y)]
538-
ymhi = math.log(max(math.fabs(np.nanmax(yd)),1e-7),10)
539-
ymlo = math.log(max(math.fabs(np.nanmin(yd)),1e-7),10)
540-
secondary_y = _auto_secondary_y( panels, panid, ymlo, ymhi )
541-
else:
542-
secondary_y = apdict['secondary_y']
543-
#print("apdict['secondary_y'] says secondary_y is",secondary_y)
544-
545-
if secondary_y:
546-
ax = panels.at[panid,'axes'][1]
547-
panels.at[panid,'used2nd'] = True
548-
else:
549-
ax = panels.at[panid,'axes'][0]
550-
551-
aptype = apdict['type']
552-
if aptype == 'scatter':
553-
size = apdict['markersize']
554-
mark = apdict['marker']
555-
color = apdict['color']
556-
alpha = apdict['alpha']
557-
if isinstance(mark,(list,tuple,np.ndarray)):
558-
_mscatter(xdates,ydata,ax=ax,m=mark,s=size,color=color,alpha=alpha)
559-
else:
560-
ax.scatter(xdates,ydata,s=size,marker=mark,color=color,alpha=alpha)
561-
elif aptype == 'bar':
562-
width = 0.8 if apdict['width'] is None else apdict['width']
563-
bottom = apdict['bottom']
564-
color = apdict['color']
565-
alpha = apdict['alpha']
566-
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha)
567-
elif aptype == 'line':
568-
ls = apdict['linestyle']
569-
color = apdict['color']
570-
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
571-
alpha = apdict['alpha']
572-
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha)
573-
else:
574-
raise ValueError('addplot type "'+str(aptype)+'" NOT yet supported.')
575-
576-
if apdict['mav'] is not None:
577-
apmavprices = _plot_mav(ax,config,xdates,ydata,apdict['mav'])
578-
579-
if (apdict["ylabel"] is not None):
580-
ax.set_ylabel(apdict["ylabel"])
581-
582-
if apdict['ylim'] is not None:
583-
ax.set_ylim(apdict['ylim'][0],apdict['ylim'][1])
584-
#elif config['tight_layout']:
585-
# ax.set_xlim(minx,maxx)
586-
# ydelta = 0.01 * (maxy-miny)
587-
# ax.set_ylim(miny-ydelta,maxy+ydelta)
588-
#else:
589-
# corners = (minx, miny), (maxx, maxy)
590-
# ax.update_datalim(corners)
492+
havedf = False # must be a single series or array
493+
apdata = [apdata,] # make it iterable
494+
495+
for column in apdata:
496+
ydata = apdata.loc[:,column] if havedf else column
497+
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config)
498+
if (apdict["ylabel"] is not None):
499+
ax.set_ylabel(apdict["ylabel"])
500+
if apdict['ylim'] is not None:
501+
ax.set_ylim(apdict['ylim'][0],apdict['ylim'][1])
502+
#elif config['tight_layout']:
503+
# ax.set_xlim(minx,maxx)
504+
# ydelta = 0.01 * (maxy-miny)
505+
# ax.set_ylim(miny-ydelta,maxy+ydelta)
506+
#else:
507+
# corners = (minx, miny), (maxx, maxy)
508+
# ax.update_datalim(corners)
591509

592510
if config['fill_between'] is not None:
593511
fb = config['fill_between']
@@ -704,6 +622,95 @@ def plot( data, **kwargs ):
704622
# print('rcpdfhead(3)=',rcpdf.head(3))
705623
# return # rcpdf
706624

625+
def _addplot_collections(panid,panels,apdict,xdates,config):
626+
627+
apdata = apdict['data']
628+
aptype = apdict['type']
629+
630+
#--------------------------------------------------------------#
631+
# Note: _auto_secondary_y() sets the 'magnitude' column in the
632+
# `panels` dataframe, which is needed for automatically
633+
# determining if secondary_y is needed. Therefore we call
634+
# _auto_secondary_y() for *all* addplots, even those that
635+
# are set to True or False (not 'auto') for secondary_y
636+
# because their magnitudes may be needed if *any* apdicts
637+
# contain secondary_y='auto'.
638+
# In theory we could first loop through all apdicts to see
639+
# if any have secondary_y='auto', but since that is the
640+
# default value, we will just assume we have at least one.
641+
642+
valid_apc_types = ['ohlc','candle']
643+
if aptype not in valid_apc_types:
644+
raise TypeError('Invalid aptype='+str(aptype)+'. Must be one of '+str(valid_apc_types))
645+
if not isinstance(apdata,pd.DataFrame):
646+
raise TypeError('addplot type "'+aptype+'" MUST be accompanied by addplot data of type `pd.DataFrame`')
647+
d,o,h,l,c,v = _check_and_prepare_data(apdata,config)
648+
collections = _construct_mpf_collections(aptype,d,xdates,o,h,l,c,v,config,config['style'])
649+
lo = math.log(max(math.fabs(np.nanmin(l)),1e-7),10) - 0.5
650+
hi = math.log(max(math.fabs(np.nanmax(h)),1e-7),10) + 0.5
651+
secondary_y = _auto_secondary_y( panels, panid, lo, hi )
652+
if 'auto' != apdict['secondary_y']:
653+
secondary_y = apdict['secondary_y']
654+
if secondary_y:
655+
ax = panels.at[panid,'axes'][1]
656+
panels.at[panid,'used2nd'] = True
657+
else:
658+
ax = panels.at[panid,'axes'][0]
659+
for coll in collections:
660+
ax.add_collection(coll)
661+
if apdict['mav'] is not None:
662+
apmavprices = _plot_mav(ax,config,xdates,c,apdict['mav'])
663+
ax.autoscale_view()
664+
return ax
665+
666+
def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
667+
secondary_y = False
668+
if apdict['secondary_y'] == 'auto':
669+
yd = [y for y in ydata if not math.isnan(y)]
670+
ymhi = math.log(max(math.fabs(np.nanmax(yd)),1e-7),10)
671+
ymlo = math.log(max(math.fabs(np.nanmin(yd)),1e-7),10)
672+
secondary_y = _auto_secondary_y( panels, panid, ymlo, ymhi )
673+
else:
674+
secondary_y = apdict['secondary_y']
675+
#print("apdict['secondary_y'] says secondary_y is",secondary_y)
676+
677+
if secondary_y:
678+
ax = panels.at[panid,'axes'][1]
679+
panels.at[panid,'used2nd'] = True
680+
else:
681+
ax = panels.at[panid,'axes'][0]
682+
683+
aptype = apdict['type']
684+
if aptype == 'scatter':
685+
size = apdict['markersize']
686+
mark = apdict['marker']
687+
color = apdict['color']
688+
alpha = apdict['alpha']
689+
if isinstance(mark,(list,tuple,np.ndarray)):
690+
_mscatter(xdates,ydata,ax=ax,m=mark,s=size,color=color,alpha=alpha)
691+
else:
692+
ax.scatter(xdates,ydata,s=size,marker=mark,color=color,alpha=alpha)
693+
elif aptype == 'bar':
694+
width = 0.8 if apdict['width'] is None else apdict['width']
695+
bottom = apdict['bottom']
696+
color = apdict['color']
697+
alpha = apdict['alpha']
698+
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha)
699+
elif aptype == 'line':
700+
ls = apdict['linestyle']
701+
color = apdict['color']
702+
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
703+
alpha = apdict['alpha']
704+
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha)
705+
else:
706+
raise ValueError('addplot type "'+str(aptype)+'" NOT yet supported.')
707+
708+
if apdict['mav'] is not None:
709+
apmavprices = _plot_mav(ax,config,xdates,ydata,apdict['mav'])
710+
711+
return ax
712+
713+
707714
def _set_ylabels_side(ax_pri,ax_sec,primary_on_right):
708715
# put the primary axis on one side,
709716
# and the twinx() on the "other" side:

0 commit comments

Comments
 (0)