diff --git a/pandas/plotting/_matplotlib/boxplot.py b/pandas/plotting/_matplotlib/boxplot.py index b33daf39de37c..01fe98a6f5403 100644 --- a/pandas/plotting/_matplotlib/boxplot.py +++ b/pandas/plotting/_matplotlib/boxplot.py @@ -1,4 +1,5 @@ from collections import namedtuple +from typing import TYPE_CHECKING import warnings from matplotlib.artist import setp @@ -14,6 +15,9 @@ from pandas.plotting._matplotlib.style import _get_standard_colors from pandas.plotting._matplotlib.tools import _flatten, _subplots +if TYPE_CHECKING: + from matplotlib.axes import Axes + class BoxPlot(LinePlot): _kind = "box" @@ -150,7 +154,7 @@ def _make_plot(self): labels = [pprint_thing(key) for key in range(len(labels))] self._set_ticklabels(ax, labels) - def _set_ticklabels(self, ax, labels): + def _set_ticklabels(self, ax: "Axes", labels): if self.orientation == "vertical": ax.set_xticklabels(labels) else: @@ -292,7 +296,7 @@ def maybe_color_bp(bp, **kwds): if not kwds.get("capprops"): setp(bp["caps"], color=colors[3], alpha=1) - def plot_group(keys, values, ax): + def plot_group(keys, values, ax: "Axes"): keys = [pprint_thing(x) for x in keys] values = [np.asarray(remove_na_arraylike(v), dtype=object) for v in values] bp = ax.boxplot(values, **kwds) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 4d23a5e5fc249..93ba9bd26630b 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -1,5 +1,5 @@ import re -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple import warnings from matplotlib.artist import Artist @@ -45,6 +45,7 @@ if TYPE_CHECKING: from matplotlib.axes import Axes + from matplotlib.axis import Axis class MPLPlot: @@ -68,16 +69,10 @@ def _kind(self): _pop_attributes = [ "label", "style", - "logy", - "logx", - "loglog", "mark_right", "stacked", ] _attr_defaults = { - "logy": False, - "logx": False, - "loglog": False, "mark_right": True, "stacked": False, } @@ -167,6 +162,9 @@ def __init__( self.legend_handles: List[Artist] = [] self.legend_labels: List[Label] = [] + self.logx = kwds.pop("logx", False) + self.logy = kwds.pop("logy", False) + self.loglog = kwds.pop("loglog", False) for attr in self._pop_attributes: value = kwds.pop(attr, self._attr_defaults.get(attr, None)) setattr(self, attr, value) @@ -283,11 +281,11 @@ def generate(self): def _args_adjust(self): pass - def _has_plotted_object(self, ax): + def _has_plotted_object(self, ax: "Axes") -> bool: """check whether ax has data""" return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0 - def _maybe_right_yaxis(self, ax, axes_num): + def _maybe_right_yaxis(self, ax: "Axes", axes_num): if not self.on_right(axes_num): # secondary axes may be passed via ax kw return self._get_ax_layer(ax) @@ -523,7 +521,7 @@ def _adorn_subplots(self): raise ValueError(msg) self.axes[0].set_title(self.title) - def _apply_axis_properties(self, axis, rot=None, fontsize=None): + def _apply_axis_properties(self, axis: "Axis", rot=None, fontsize=None): """ Tick creation within matplotlib is reasonably expensive and is internally deferred until accessed as Ticks are created/destroyed @@ -540,7 +538,7 @@ def _apply_axis_properties(self, axis, rot=None, fontsize=None): label.set_fontsize(fontsize) @property - def legend_title(self): + def legend_title(self) -> Optional[str]: if not isinstance(self.data.columns, ABCMultiIndex): name = self.data.columns.name if name is not None: @@ -591,7 +589,7 @@ def _make_legend(self): if ax.get_visible(): ax.legend(loc="best") - def _get_ax_legend_handle(self, ax): + def _get_ax_legend_handle(self, ax: "Axes"): """ Take in axes and return ax, legend and handle under different scenarios """ @@ -616,7 +614,7 @@ def plt(self): _need_to_set_index = False - def _get_xticks(self, convert_period=False): + def _get_xticks(self, convert_period: bool = False): index = self.data.index is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time") @@ -646,7 +644,7 @@ def _get_xticks(self, convert_period=False): @classmethod @register_pandas_matplotlib_converters - def _plot(cls, ax, x, y, style=None, is_errorbar=False, **kwds): + def _plot(cls, ax: "Axes", x, y, style=None, is_errorbar: bool = False, **kwds): mask = isna(y) if mask.any(): y = np.ma.array(y) @@ -667,10 +665,10 @@ def _plot(cls, ax, x, y, style=None, is_errorbar=False, **kwds): if style is not None: args = (x, y, style) else: - args = (x, y) + args = (x, y) # type:ignore[assignment] return ax.plot(*args, **kwds) - def _get_index_name(self): + def _get_index_name(self) -> Optional[str]: if isinstance(self.data.index, ABCMultiIndex): name = self.data.index.names if com.any_not_none(*name): @@ -877,7 +875,7 @@ def _get_subplots(self): ax for ax in self.axes[0].get_figure().get_axes() if isinstance(ax, Subplot) ] - def _get_axes_layout(self): + def _get_axes_layout(self) -> Tuple[int, int]: axes = self._get_subplots() x_set = set() y_set = set() @@ -916,15 +914,15 @@ def __init__(self, data, x, y, **kwargs): self.y = y @property - def nseries(self): + def nseries(self) -> int: return 1 - def _post_plot_logic(self, ax, data): + def _post_plot_logic(self, ax: "Axes", data): x, y = self.x, self.y ax.set_ylabel(pprint_thing(y)) ax.set_xlabel(pprint_thing(x)) - def _plot_colorbar(self, ax, **kwds): + def _plot_colorbar(self, ax: "Axes", **kwds): # Addresses issues #10611 and #10678: # When plotting scatterplots and hexbinplots in IPython # inline backend the colorbar axis height tends not to @@ -1080,7 +1078,7 @@ def __init__(self, data, **kwargs): if "x_compat" in self.kwds: self.x_compat = bool(self.kwds.pop("x_compat")) - def _is_ts_plot(self): + def _is_ts_plot(self) -> bool: # this is slightly deceptive return not self.x_compat and self.use_index and self._use_dynamic_x() @@ -1139,7 +1137,9 @@ def _make_plot(self): ax.set_xlim(left, right) @classmethod - def _plot(cls, ax, x, y, style=None, column_num=None, stacking_id=None, **kwds): + def _plot( + cls, ax: "Axes", x, y, style=None, column_num=None, stacking_id=None, **kwds + ): # column_num is used to get the target column from plotf in line and # area plots if column_num == 0: @@ -1183,7 +1183,7 @@ def _get_stacking_id(self): return None @classmethod - def _initialize_stacker(cls, ax, stacking_id, n): + def _initialize_stacker(cls, ax: "Axes", stacking_id, n: int): if stacking_id is None: return if not hasattr(ax, "_stacker_pos_prior"): @@ -1194,7 +1194,7 @@ def _initialize_stacker(cls, ax, stacking_id, n): ax._stacker_neg_prior[stacking_id] = np.zeros(n) @classmethod - def _get_stacked_values(cls, ax, stacking_id, values, label): + def _get_stacked_values(cls, ax: "Axes", stacking_id, values, label): if stacking_id is None: return values if not hasattr(ax, "_stacker_pos_prior"): @@ -1213,7 +1213,7 @@ def _get_stacked_values(cls, ax, stacking_id, values, label): ) @classmethod - def _update_stacker(cls, ax, stacking_id, values): + def _update_stacker(cls, ax: "Axes", stacking_id, values): if stacking_id is None: return if (values >= 0).all(): @@ -1221,7 +1221,7 @@ def _update_stacker(cls, ax, stacking_id, values): elif (values <= 0).all(): ax._stacker_neg_prior[stacking_id] += values - def _post_plot_logic(self, ax, data): + def _post_plot_logic(self, ax: "Axes", data): from matplotlib.ticker import FixedLocator def get_label(i): @@ -1276,7 +1276,7 @@ def __init__(self, data, **kwargs): @classmethod def _plot( cls, - ax, + ax: "Axes", x, y, style=None, @@ -1318,7 +1318,7 @@ def _plot( res = [rect] return res - def _post_plot_logic(self, ax, data): + def _post_plot_logic(self, ax: "Axes", data): LinePlot._post_plot_logic(self, ax, data) if self.ylim is None: @@ -1372,7 +1372,7 @@ def _args_adjust(self): self.left = np.array(self.left) @classmethod - def _plot(cls, ax, x, y, w, start=0, log=False, **kwds): + def _plot(cls, ax: "Axes", x, y, w, start=0, log=False, **kwds): return ax.bar(x, y, w, bottom=start, log=log, **kwds) @property @@ -1454,7 +1454,7 @@ def _make_plot(self): ) self._add_legend_handle(rect, label, index=i) - def _post_plot_logic(self, ax, data): + def _post_plot_logic(self, ax: "Axes", data): if self.use_index: str_index = [pprint_thing(key) for key in data.index] else: @@ -1466,7 +1466,7 @@ def _post_plot_logic(self, ax, data): self._decorate_ticks(ax, name, str_index, s_edge, e_edge) - def _decorate_ticks(self, ax, name, ticklabels, start_edge, end_edge): + def _decorate_ticks(self, ax: "Axes", name, ticklabels, start_edge, end_edge): ax.set_xlim((start_edge, end_edge)) if self.xticks is not None: @@ -1489,10 +1489,10 @@ def _start_base(self): return self.left @classmethod - def _plot(cls, ax, x, y, w, start=0, log=False, **kwds): + def _plot(cls, ax: "Axes", x, y, w, start=0, log=False, **kwds): return ax.barh(x, y, w, left=start, log=log, **kwds) - def _decorate_ticks(self, ax, name, ticklabels, start_edge, end_edge): + def _decorate_ticks(self, ax: "Axes", name, ticklabels, start_edge, end_edge): # horizontal bars ax.set_ylim((start_edge, end_edge)) ax.set_yticks(self.tick_pos) diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index ee41479b3c7c9..ffd46d1b191db 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import numpy as np from pandas.core.dtypes.common import is_integer, is_list_like @@ -8,6 +10,9 @@ from pandas.plotting._matplotlib.core import LinePlot, MPLPlot from pandas.plotting._matplotlib.tools import _flatten, _set_ticks_props, _subplots +if TYPE_CHECKING: + from matplotlib.axes import Axes + class HistPlot(LinePlot): _kind = "hist" @@ -90,7 +95,7 @@ def _make_plot_keywords(self, kwds, y): kwds["bins"] = self.bins return kwds - def _post_plot_logic(self, ax, data): + def _post_plot_logic(self, ax: "Axes", data): if self.orientation == "horizontal": ax.set_xlabel("Frequency") else: