Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,10 @@ def test_parallel_coordinates(self):
path = os.path.join(curpath(), 'data/iris.csv')
df = read_csv(path)
_check_plot_works(parallel_coordinates, df, 'Name')
_check_plot_works(parallel_coordinates, df, 'Name',
colors=('#556270', '#4ECDC4', '#C7F464'))
_check_plot_works(parallel_coordinates, df, 'Name',
colors=['dodgerblue', 'aquamarine', 'seagreen'])

@slow
def test_radviz(self):
Expand Down
16 changes: 13 additions & 3 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
return fig


def parallel_coordinates(data, class_column, cols=None, ax=None, **kwds):
def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None,
**kwds):
"""Parallel coordinates plotting.

Parameters:
Expand All @@ -420,6 +421,7 @@ def parallel_coordinates(data, class_column, cols=None, ax=None, **kwds):
class_column: Column name containing class names
cols: A list of column names to use, optional
ax: matplotlib axis object, optional
colors: A list or tuple of colors to use for the different classes, optional
kwds: A list of keywords for matplotlib plot method

Returns:
Expand Down Expand Up @@ -449,17 +451,25 @@ def random_color(column):
if ax == None:
ax = plt.gca()

# if user has not specified colors to use, choose at random
if colors is None:
colors = dict((kls, random_color(kls)) for kls in classes)
else:
if len(colors) != len(classes):
raise ValueError('Number of colors must match number of classes')
colors = dict((kls, colors[i]) for i, kls in enumerate(classes))

for i in range(n):
row = df.irow(i).values
y = row
kls = class_col.iget_value(i)
if com.pprint_thing(kls) not in used_legends:
label = com.pprint_thing(kls)
used_legends.add(label)
ax.plot(x, y, color=random_color(kls),
ax.plot(x, y, color=colors[kls],
label=label, **kwds)
else:
ax.plot(x, y, color=random_color(kls), **kwds)
ax.plot(x, y, color=colors[kls], **kwds)

for i in range(ncols):
ax.axvline(i, linewidth=1, color='black')
Expand Down