diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 0724799ced6f2..708f8143de3d5 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -460,6 +460,10 @@ def test_parallel_coordinates(self): path = os.path.join(curpath(), 'data/iris.csv') df = read_csv(path) _check_plot_works(parallel_coordinates, df, 'Name') + _check_plot_works(parallel_coordinates, df, 'Name', + colors=('#556270', '#4ECDC4', '#C7F464')) + _check_plot_works(parallel_coordinates, df, 'Name', + colors=['dodgerblue', 'aquamarine', 'seagreen']) @slow def test_radviz(self): diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 2e6faf5eb9362..c49d150aabd8a 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -411,7 +411,8 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds): return fig -def parallel_coordinates(data, class_column, cols=None, ax=None, **kwds): +def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None, + **kwds): """Parallel coordinates plotting. Parameters: @@ -420,6 +421,7 @@ def parallel_coordinates(data, class_column, cols=None, ax=None, **kwds): class_column: Column name containing class names cols: A list of column names to use, optional ax: matplotlib axis object, optional + colors: A list or tuple of colors to use for the different classes, optional kwds: A list of keywords for matplotlib plot method Returns: @@ -449,6 +451,14 @@ def random_color(column): if ax == None: ax = plt.gca() + # if user has not specified colors to use, choose at random + if colors is None: + colors = dict((kls, random_color(kls)) for kls in classes) + else: + if len(colors) != len(classes): + raise ValueError('Number of colors must match number of classes') + colors = dict((kls, colors[i]) for i, kls in enumerate(classes)) + for i in range(n): row = df.irow(i).values y = row @@ -456,10 +466,10 @@ def random_color(column): if com.pprint_thing(kls) not in used_legends: label = com.pprint_thing(kls) used_legends.add(label) - ax.plot(x, y, color=random_color(kls), + ax.plot(x, y, color=colors[kls], label=label, **kwds) else: - ax.plot(x, y, color=random_color(kls), **kwds) + ax.plot(x, y, color=colors[kls], **kwds) for i in range(ncols): ax.axvline(i, linewidth=1, color='black')