diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index f2fd98169b585..3c28038927401 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -387,6 +387,16 @@ def scat2(x, y, by=None, ax=None, figsize=None): _check_plot_works(scat2, 0, 1) grouper = Series(np.repeat([1, 2, 3, 4, 5], 20), df.index) _check_plot_works(scat2, 0, 1, by=grouper) + _check_plot_works(scat2, 0, 1, by=grouper, sharex=False, + sharey=True, xlim=(1, 3), ylim=(3, 5), color='red') + + xrot, yrot = 30, 30 + fig = scat2(0, 1, xlabelsize=xf, ylabelsize=yf + for ax in fig.axes + ytick = ax.get_yticklabels()[0] + xtick = ax.get_xticklabels()[0] + self.assertAlmostEqual(ytick.get_fontsize(), yf) + self.assertAlmostEqual(xtick.get_fontsize(), xf) @slow def test_andrews_curves(self): diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 60ed0c70d516b..c5ba80ed9b971 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -1431,34 +1431,56 @@ def format_date_labels(ax, rot): pass -def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False): +def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False, + xlabelsize=None, ylabelsize=None, + sharex=True, sharey=True, xlim=None, ylim=None, + **kwds): """ + Draw scatter plot of the DataFrame's series using matplotlib / pylab. - Returns - ------- - fig : matplotlib.Figure + Parameters + ---------- + data : Dataframe + x : column name of Dataframe for x axis + y : column name of Dataframe for y axis + by : column in the DataFrame to group by + ax : matplotlib axes object, default None + figsize : + grid : boolean, default True + Whether to show axis grid lines + xlabelsize : int, default None + If specified changes the x-axis label size + ylabelsize : int, default None + If specified changes the y-axis label size + sharex : bool, if True, the X axis will be shared amongst all subplots. + sharey : bool, if True, the Y axis will be shared amongst all subplots. + xlim : 2-tuple/list + ylim : 2-tuple/list + kwds : other plotting keyword arguments + To be passed to scatter function """ import matplotlib.pyplot as plt - def plot_group(group, ax): + def plot_group(group, ax, *kwds): xvals = group[x].values yvals = group[y].values - ax.scatter(xvals, yvals) - ax.grid(grid) + ax.scatter(xvals, yvals, **kwds) + _decorate_axes(ax, grid=grid, xlabelsize=xlabelsize, + ylabelsize=ylabelsize, xlim=xlim, ylim=ylim) if by is not None: - fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax) + fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, + sharex=sharex, sharey=sharey, ax=ax, **kwds) else: if ax is None: fig = plt.figure() ax = fig.add_subplot(111) else: fig = ax.get_figure() - plot_group(data, ax) - ax.set_ylabel(com._stringify(y)) - ax.set_xlabel(com._stringify(x)) - - ax.grid(grid) + plot_group(data, ax, **kwds) + _decorate_axes(ax, grid=grid, xlabelsize=xlabelsize, + ylabelsize=ylabelsize, xlim=xlim, ylim=ylim, + xlabel=com._stringify(x), ylabel=com._stringify(y)) return fig @@ -1735,6 +1757,49 @@ def _get_layout(nplots): else: return k, k +def _decorate_axes(axes, title=None, legend=None, + xlim=None, ylim=None, grid=None, + xticks=None, yticks=None, xticklabels=None, yticklabels=None, + xlabelsize=None, ylabelsize=None, xrot=None, yrot=None, + xlabel=None, ylabel=None): + import matplotlib.pyplot as plt + import matplotlib.axes + assert isinstance(axes, matplotlib.axes.SubplotBase) + + if title is not None: + axes.set_title(title) + + if legend == True: + axes.legend() + elif isinstance(legend, dict): + axes.legend(**legend) + + if xticks is not None: + axes.xaxis.set_ticks(xticks) + + if xticklabels is not None: + axes.xaxis.set_ticklabels(xtickslabels) + if xlabelsize is not None: + plt.setp(axes.get_xticklabels(), fontsize=xlabelsize) + if xrot is not None: + plt.setp(axes.get_xticklabels(), rotation=xrot) + + if xlabel is not None: + axes.set_xlabel(xlabel) + + if yticks is not None: + axes.yaxis.set_ticks(yticks) + + if yticklabels is not None: + axes.yaxis.set_ticklabels(yticklabels) + if ylabelsize is not None: + plt.setp(axes.get_yticklabels(), fontsize=ylabelsize) + if yrot is not None: + plt.setp(axes.get_yticklabels(), rotation=yrot) + + if ylabel is not None: + axes.set_ylabel(ylabel) + # copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0 def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,