Skip to content

Commit 3becf41

Browse files
committed
Preliminary done
1 parent d777f43 commit 3becf41

File tree

3 files changed

+95
-21
lines changed

3 files changed

+95
-21
lines changed

src/mplfinance/_arg_validators.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,17 @@ def _valid_mav(value, is_period=True):
142142
return True
143143
return False
144144

145+
def _colors_validator(value):
146+
if not isinstance(value, list):
147+
return False
148+
149+
for v in value:
150+
if v:
151+
if not (isinstance(v, dict) or isinstance(v, str)):
152+
return False
153+
154+
return True
155+
145156

146157
def _hlines_validator(value):
147158
if isinstance(value,dict):

src/mplfinance/_utils.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from six.moves import zip
2222

23-
def _check_input(opens, closes, highs, lows):
23+
def _check_input(opens, closes, highs, lows, colors=None):
2424
"""Checks that *opens*, *highs*, *lows* and *closes* have the same length.
2525
NOTE: this code assumes if any value open, high, low, close is
2626
missing (*-1*) they all are missing
@@ -46,6 +46,10 @@ def _check_input(opens, closes, highs, lows):
4646
if not same_length:
4747
raise ValueError('O,H,L,C must have the same length!')
4848

49+
if colors:
50+
if len(opens) != len(colors):
51+
raise ValueError('O,H,L,C and Colors must have the same length!')
52+
4953
o = np.where(np.isnan(opens))[0]
5054
h = np.where(np.isnan(highs))[0]
5155
l = np.where(np.isnan(lows))[0]
@@ -85,11 +89,11 @@ def _check_and_convert_xlim_configuration(data, config):
8589
return xlim
8690

8791

88-
def _construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style):
92+
def _construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style,colors):
8993
collections = None
9094
if ptype == 'candle' or ptype == 'candlestick':
9195
collections = _construct_candlestick_collections(xdates, opens, highs, lows, closes,
92-
marketcolors=style['marketcolors'],config=config )
96+
marketcolors=style['marketcolors'],config=config, colors=colors )
9397

9498
elif ptype =='hollow_and_filled':
9599
collections = _construct_hollow_candlestick_collections(xdates, opens, highs, lows, closes,
@@ -176,16 +180,45 @@ def coalesce_volume_dates(in_volumes, in_dates, indexes):
176180
return volumes, dates
177181

178182

179-
def _updown_colors(upcolor,downcolor,opens,closes,use_prev_close=False):
180-
if upcolor == downcolor:
181-
return upcolor
182-
cmap = {True : upcolor, False : downcolor}
183-
if not use_prev_close:
184-
return [ cmap[opn < cls] for opn,cls in zip(opens,closes) ]
183+
def _updown_colors(upcolor,downcolor,opens,closes,use_prev_close=False,colors=None):
184+
if not colors:
185+
if upcolor == downcolor:
186+
return upcolor
187+
cmap = {True : upcolor, False : downcolor}
188+
if not use_prev_close:
189+
return [ cmap[opn < cls] for opn,cls in zip(opens,closes) ]
190+
else:
191+
first = cmap[opens[0] < closes[0]]
192+
_list = [ cmap[pre < cls] for cls,pre in zip(closes[1:], closes) ]
193+
return [first] + _list
185194
else:
186-
first = cmap[opens[0] < closes[0]]
187-
_list = [ cmap[pre < cls] for cls,pre in zip(closes[1:], closes) ]
188-
return [first] + _list
195+
cmap = {True: 'up', False: 'down'}
196+
default = {'up': upcolor, 'down': downcolor}
197+
custom = []
198+
if not use_prev_close:
199+
for i in range(len(opens)):
200+
opn = opens[i]
201+
cls = closes[i]
202+
if colors[i]:
203+
custom.append(colors[i][cmap[opn < cls]])
204+
else:
205+
custom.append(default[cmap[opn < cls]])
206+
else:
207+
if color[0]:
208+
custom.append(colors[0][cmap[opens[0] < closes[0]]])
209+
else:
210+
custom.append(default[cmap[opens[0] < closes[0]]])
211+
212+
for i in range(len(closes) - 1):
213+
pre = closes[1:][i]
214+
cls = closes[i]
215+
if colors[i]:
216+
custom.append(colors[i][cmap[pre < cls]])
217+
else:
218+
custom.append(default[cmap[pre < cls]])
219+
220+
return custom
221+
189222

190223

191224
def _updownhollow_colors(upcolor,downcolor,hollowcolor,opens,closes):
@@ -525,7 +558,7 @@ def _construct_ohlc_collections(dates, opens, highs, lows, closes, marketcolors=
525558
return [rangeCollection, openCollection, closeCollection]
526559

527560

528-
def _construct_candlestick_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None):
561+
def _construct_candlestick_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None, colors=None):
529562
"""Represent the open, close as a bar line and high low range as a
530563
vertical line.
531564
@@ -552,8 +585,8 @@ def _construct_candlestick_collections(dates, opens, highs, lows, closes, market
552585
ret : list
553586
(lineCollection, barCollection)
554587
"""
555-
556-
_check_input(opens, highs, lows, closes)
588+
589+
_check_input(opens, highs, lows, closes, colors)
557590

558591
if marketcolors is None:
559592
marketcolors = _get_mpfstyle('classic')['marketcolors']
@@ -581,17 +614,34 @@ def _construct_candlestick_collections(dates, opens, highs, lows, closes, market
581614

582615
alpha = marketcolors['alpha']
583616

617+
candle_c = None
618+
wick_c = None
619+
edge_c = None
620+
if colors:
621+
candle_c = []
622+
wick_c = []
623+
edge_c = []
624+
for color in colors:
625+
if color:
626+
candle_c.append({'up': mcolors.to_rgba(color['candle']['up'], alpha), 'down': mcolors.to_rgba(color['candle']['down'], alpha)})
627+
wick_c.append({'up': mcolors.to_rgba(color['wick']['up'], 1), 'down': mcolors.to_rgba(color['wick']['down'], 1)})
628+
edge_c.append({'up': mcolors.to_rgba(color['edge']['up'], 1), 'down': mcolors.to_rgba(color['edge']['down'], 1)})
629+
else:
630+
candle_c.append(None)
631+
wick_c.append(None)
632+
edge_c.append(None)
633+
584634
uc = mcolors.to_rgba(marketcolors['candle'][ 'up' ], alpha)
585635
dc = mcolors.to_rgba(marketcolors['candle']['down'], alpha)
586-
colors = _updown_colors(uc, dc, opens, closes)
636+
colors = _updown_colors(uc, dc, opens, closes, colors=candle_c)
587637

588638
uc = mcolors.to_rgba(marketcolors['edge'][ 'up' ], 1.0)
589639
dc = mcolors.to_rgba(marketcolors['edge']['down'], 1.0)
590-
edgecolor = _updown_colors(uc, dc, opens, closes)
640+
edgecolor = _updown_colors(uc, dc, opens, closes, colors=edge_c)
591641

592642
uc = mcolors.to_rgba(marketcolors['wick'][ 'up' ], 1.0)
593643
dc = mcolors.to_rgba(marketcolors['wick']['down'], 1.0)
594-
wickcolor = _updown_colors(uc, dc, opens, closes)
644+
wickcolor = _updown_colors(uc, dc, opens, closes, colors=wick_c)
595645

596646
lw = config['_width_config']['candle_linewidth']
597647

src/mplfinance/plotting.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from mplfinance._arg_validators import _scale_padding_validator, _yscale_validator
4141
from mplfinance._arg_validators import _valid_panel_id, _check_for_external_axes
4242
from mplfinance._arg_validators import _xlim_validator
43+
from mplfinance._arg_validators import _colors_validator
4344

4445
from mplfinance._panels import _build_panels
4546
from mplfinance._panels import _set_ticks_on_bottom_panel_only
@@ -49,6 +50,8 @@
4950
from mplfinance._helpers import _num_or_seq_of_num
5051
from mplfinance._helpers import _adjust_color_brightness
5152

53+
from mplfinance._styles import make_marketcolors
54+
5255
VALID_PMOVE_TYPES = ['renko', 'pnf']
5356

5457
DEFAULT_FIGRATIO = (8.00,5.75)
@@ -125,6 +128,9 @@ def _valid_plot_kwargs():
125128

126129
'marketcolors' : { 'Default' : None, # use 'style' for default, instead.
127130
'Validator' : lambda value: isinstance(value,dict) },
131+
132+
'colors' : { 'Default' : None, # use default style instead.
133+
'Validator' : lambda value: _colors_validator(value) },
128134

129135
'no_xgaps' : { 'Default' : True, # None means follow default logic below:
130136
'Validator' : lambda value: _warn_no_xgaps_deprecated(value) },
@@ -391,14 +397,21 @@ def plot( data, **kwargs ):
391397
rwc = config['return_width_config']
392398
if isinstance(rwc,dict) and len(rwc)==0:
393399
config['return_width_config'].update(config['_width_config'])
394-
400+
401+
if config['colors']:
402+
colors = config['colors']
403+
for c in range(len(colors)):
404+
if isinstance(colors[c], str):
405+
config['colors'][c] = make_marketcolors(up=colors[c], down=colors[c], edge=colors[c], wick=colors[c])
406+
else:
407+
config['colors'] = None
395408

396409
collections = None
397410
if ptype == 'line':
398411
lw = config['_width_config']['line_width']
399412
axA1.plot(xdates, closes, color=config['linecolor'], linewidth=lw)
400413
else:
401-
collections =_construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style)
414+
collections =_construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style,config['colors'])
402415

403416
if ptype in VALID_PMOVE_TYPES:
404417
collections, calculated_values = collections
@@ -858,7 +871,7 @@ def _addplot_collections(panid,panels,apdict,xdates,config):
858871
if not isinstance(apdata,pd.DataFrame):
859872
raise TypeError('addplot type "'+aptype+'" MUST be accompanied by addplot data of type `pd.DataFrame`')
860873
d,o,h,l,c,v = _check_and_prepare_data(apdata,config)
861-
collections = _construct_mpf_collections(aptype,d,xdates,o,h,l,c,v,config,config['style'])
874+
collections = _construct_mpf_collections(aptype,d,xdates,o,h,l,c,v,config,config['style'],config['colors'])
862875

863876
if not external_axes_mode:
864877
lo = math.log(max(math.fabs(np.nanmin(l)),1e-7),10) - 0.5

0 commit comments

Comments
 (0)