diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 47c2d414cce37..744cdd8a5020b 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -1611,6 +1611,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, figsize=None, sharex=True, sharey=True, layout=None, rot=0, ax=None): from pandas.core.frame import DataFrame + import matplotlib.pyplot as plt # allow to specify mpl default with 'default' if figsize is None or figsize == 'default': @@ -1631,9 +1632,15 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, fig, axes = _subplots(nrows=nrows, ncols=ncols, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax) - ravel_axes = [] - for row in axes: - ravel_axes.extend(row) + if isinstance(axes, plt.Axes): + ravel_axes = [axes] + else: + ravel_axes = [] + for row in axes: + if isinstance(row, plt.Axes): + ravel_axes.append(row) + else: + ravel_axes.extend(row) for i, (key, group) in enumerate(grouped): ax = ravel_axes[i]