Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_bar_linewidth(self):
self.assert_(r.get_linewidth() == 2)

@slow
def test_1rotation(self):
def test_rotation(self):
df = DataFrame(np.random.randn(5, 5))
ax = df.plot(rot=30)
for l in ax.get_xticklabels():
Expand Down Expand Up @@ -447,6 +447,24 @@ def test_style_by_column(self):
for i, l in enumerate(ax.get_lines()[:len(markers)]):
self.assertEqual(l.get_marker(), markers[i])

@slow
def test_line_colors(self):
import matplotlib.pyplot as plt

custom_colors = 'rgcby'

plt.close('all')
df = DataFrame(np.random.randn(5, 5))

ax = df.plot(color=custom_colors)

lines = ax.get_lines()
for i, l in enumerate(lines):
xp = custom_colors[i]
rs = l.get_color()
self.assert_(xp == rs)


class TestDataFrameGroupByPlots(unittest.TestCase):

@classmethod
Expand Down
35 changes: 21 additions & 14 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# pylint: disable=E1101
from itertools import izip
import datetime
import warnings
import re

import numpy as np
Expand Down Expand Up @@ -852,6 +853,14 @@ class LinePlot(MPLPlot):
def __init__(self, data, **kwargs):
self.mark_right = kwargs.pop('mark_right', True)
MPLPlot.__init__(self, data, **kwargs)
if 'color' not in self.kwds and 'colors' in self.kwds:
warnings.warn(("'colors' is being deprecated. Please use 'color'"
"instead of 'colors'"))
colors = self.kwds.pop('colors')
self.kwds['color'] = colors
if 'color' in self.kwds and isinstance(self.data, Series):
#support series.plot(color='green')
self.kwds['color'] = [self.kwds['color']]

def _index_freq(self):
from pandas.core.frame import DataFrame
Expand Down Expand Up @@ -889,14 +898,12 @@ def _use_dynamic_x(self):
def _get_colors(self):
import matplotlib.pyplot as plt
cycle = ''.join(plt.rcParams.get('axes.color_cycle', list('bgrcmyk')))
has_colors = 'colors' in self.kwds
colors = self.kwds.pop('colors', cycle)
return has_colors, colors

def _maybe_add_color(self, has_colors, colors, kwds, style, i):
if (not has_colors and
(style is None or re.match('[a-z]+', style) is None)
and 'color' not in kwds):
has_colors = 'color' in self.kwds
colors = self.kwds.get('color', cycle)
return colors

def _maybe_add_color(self, colors, kwds, style, i):
if style is None or re.match('[a-z]+', style) is None:
kwds['color'] = colors[i % len(colors)]

def _make_plot(self):
Expand All @@ -910,13 +917,13 @@ def _make_plot(self):
x = self._get_xticks(convert_period=True)

plotf = self._get_plot_function()
has_colors, colors = self._get_colors()
colors = self._get_colors()

for i, (label, y) in enumerate(self._iter_data()):
ax = self._get_ax(i)
style = self._get_style(i, label)
kwds = self.kwds.copy()
self._maybe_add_color(has_colors, colors, kwds, style, i)
self._maybe_add_color(colors, kwds, style, i)

label = com.pprint_thing(label) # .encode('utf-8')

Expand Down Expand Up @@ -944,7 +951,7 @@ def _make_plot(self):
def _make_ts_plot(self, data, **kwargs):
from pandas.tseries.plotting import tsplot
kwargs = kwargs.copy()
has_colors, colors = self._get_colors()
colors = self._get_colors()

plotf = self._get_plot_function()
lines = []
Expand All @@ -960,7 +967,7 @@ def to_leg_label(label, i):
style = self.style or ''
label = com.pprint_thing(self.label)
kwds = kwargs.copy()
self._maybe_add_color(has_colors, colors, kwds, style, 0)
self._maybe_add_color(colors, kwds, style, 0)

newlines = tsplot(data, plotf, ax=ax, label=label,
style=self.style, **kwds)
Expand All @@ -975,7 +982,7 @@ def to_leg_label(label, i):
style = self._get_style(i, col)
kwds = kwargs.copy()

self._maybe_add_color(has_colors, colors, kwds, style, i)
self._maybe_add_color(colors, kwds, style, i)

newlines = tsplot(data[col], plotf, ax=ax, label=label,
style=style, **kwds)
Expand Down Expand Up @@ -1096,7 +1103,7 @@ def f(ax, x, y, w, start=None, **kwds):
return f

def _make_plot(self):
colors = self.kwds.get('color', 'brgyk')
colors = self.kwds.pop('color', 'brgyk')
rects = []
labels = []

Expand Down