Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,16 @@ def scat2(x, y, by=None, ax=None, figsize=None):
_check_plot_works(scat2, 0, 1)
grouper = Series(np.repeat([1, 2, 3, 4, 5], 20), df.index)
_check_plot_works(scat2, 0, 1, by=grouper)
_check_plot_works(scat2, 0, 1, by=grouper, sharex=False,
sharey=True, xlim=(1, 3), ylim=(3, 5), color='red')

xrot, yrot = 30, 30
fig = scat2(0, 1, xlabelsize=xf, ylabelsize=yf
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

never finished this one eh? I'm fixing now...

for ax in fig.axes
ytick = ax.get_yticklabels()[0]
xtick = ax.get_xticklabels()[0]
self.assertAlmostEqual(ytick.get_fontsize(), yf)
self.assertAlmostEqual(xtick.get_fontsize(), xf)

@slow
def test_andrews_curves(self):
Expand Down
91 changes: 78 additions & 13 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,34 +1431,56 @@ def format_date_labels(ax, rot):
pass


def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False):
def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False,
xlabelsize=None, ylabelsize=None,
sharex=True, sharey=True, xlim=None, ylim=None,
**kwds):
"""
Draw scatter plot of the DataFrame's series using matplotlib / pylab.

Returns
-------
fig : matplotlib.Figure
Parameters
----------
data : Dataframe
x : column name of Dataframe for x axis
y : column name of Dataframe for y axis
by : column in the DataFrame to group by
ax : matplotlib axes object, default None
figsize :
grid : boolean, default True
Whether to show axis grid lines
xlabelsize : int, default None
If specified changes the x-axis label size
ylabelsize : int, default None
If specified changes the y-axis label size
sharex : bool, if True, the X axis will be shared amongst all subplots.
sharey : bool, if True, the Y axis will be shared amongst all subplots.
xlim : 2-tuple/list
ylim : 2-tuple/list
kwds : other plotting keyword arguments
To be passed to scatter function
"""
import matplotlib.pyplot as plt

def plot_group(group, ax):
def plot_group(group, ax, *kwds):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you didn't try running this code?

xvals = group[x].values
yvals = group[y].values
ax.scatter(xvals, yvals)
ax.grid(grid)
ax.scatter(xvals, yvals, **kwds)
_decorate_axes(ax, grid=grid, xlabelsize=xlabelsize,
ylabelsize=ylabelsize, xlim=xlim, ylim=ylim)

if by is not None:
fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax)
fig = _grouped_plot(plot_group, data, by=by, figsize=figsize,
sharex=sharex, sharey=sharey, ax=ax, **kwds)
else:
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111)
else:
fig = ax.get_figure()
plot_group(data, ax)
ax.set_ylabel(com._stringify(y))
ax.set_xlabel(com._stringify(x))

ax.grid(grid)
plot_group(data, ax, **kwds)
_decorate_axes(ax, grid=grid, xlabelsize=xlabelsize,
ylabelsize=ylabelsize, xlim=xlim, ylim=ylim,
xlabel=com._stringify(x), ylabel=com._stringify(y))

return fig

Expand Down Expand Up @@ -1735,6 +1757,49 @@ def _get_layout(nplots):
else:
return k, k

def _decorate_axes(axes, title=None, legend=None,
xlim=None, ylim=None, grid=None,
xticks=None, yticks=None, xticklabels=None, yticklabels=None,
xlabelsize=None, ylabelsize=None, xrot=None, yrot=None,
xlabel=None, ylabel=None):
import matplotlib.pyplot as plt
import matplotlib.axes
assert isinstance(axes, matplotlib.axes.SubplotBase)

if title is not None:
axes.set_title(title)

if legend == True:
axes.legend()
elif isinstance(legend, dict):
axes.legend(**legend)

if xticks is not None:
axes.xaxis.set_ticks(xticks)

if xticklabels is not None:
axes.xaxis.set_ticklabels(xtickslabels)
if xlabelsize is not None:
plt.setp(axes.get_xticklabels(), fontsize=xlabelsize)
if xrot is not None:
plt.setp(axes.get_xticklabels(), rotation=xrot)

if xlabel is not None:
axes.set_xlabel(xlabel)

if yticks is not None:
axes.yaxis.set_ticks(yticks)

if yticklabels is not None:
axes.yaxis.set_ticklabels(yticklabels)
if ylabelsize is not None:
plt.setp(axes.get_yticklabels(), fontsize=ylabelsize)
if yrot is not None:
plt.setp(axes.get_yticklabels(), rotation=yrot)

if ylabel is not None:
axes.set_ylabel(ylabel)

# copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0

def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
Expand Down