Skip to content

Commit 2cfc0ba

Browse files
committed
ENH: use LineCollection for wide DataFrame.line plots
1 parent 0febdd9 commit 2cfc0ba

File tree

1 file changed

+68
-60
lines changed

1 file changed

+68
-60
lines changed

pandas/plotting/_matplotlib/core.py

Lines changed: 68 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,6 +1538,7 @@ def _make_legend(self) -> None:
15381538

15391539
class LinePlot(MPLPlot):
15401540
_default_rot = 0
1541+
_wide_line_threshold: int = 200
15411542

15421543
@property
15431544
def orientation(self) -> PlottingOrientation:
@@ -1547,93 +1548,100 @@ def orientation(self) -> PlottingOrientation:
15471548
def _kind(self) -> Literal["line", "area", "hist", "kde", "box"]:
15481549
return "line"
15491550

1550-
def __init__(self, data, **kwargs) -> None:
1551-
MPLPlot.__init__(self, data, **kwargs)
1551+
def __init__(self, data, **kwargs):
1552+
super().__init__(data, **kwargs)
15521553
if self.stacked:
15531554
self.data = self.data.fillna(value=0)
15541555

1555-
def _make_plot(self, fig: Figure) -> None:
1556-
threshold = 200 # switch when DataFrame has more than this many columns
1557-
can_use_lc = (
1558-
not self._is_ts_plot() # not a TS plot
1559-
and not self.stacked # stacking not requested
1560-
and not com.any_not_none(*self.errors.values()) # no error bars
1561-
and len(self.data.columns) > threshold
1556+
def _make_plot(self, fig):
1557+
is_ts = self._is_ts_plot()
1558+
1559+
use_lc = (
1560+
not is_ts
1561+
and not self.stacked
1562+
and not com.any_not_none(*self.errors.values())
1563+
and len(self.data.columns) > self._wide_line_threshold
15621564
)
1563-
if can_use_lc:
1564-
ax = self._get_ax(0)
1565-
x = self._get_xticks()
1566-
segments = [
1567-
np.column_stack((x, self.data[col].values)) for col in self.data.columns
1568-
]
1569-
base_colors = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
1570-
colors = list(itertools.islice(itertools.cycle(base_colors), len(segments)))
1571-
lc = LineCollection(
1572-
segments,
1573-
colors=colors,
1574-
linewidths=self.kwds.get("linewidth", mpl.rcParams["lines.linewidth"]),
1575-
)
1576-
ax.add_collection(lc)
1577-
ax.margins(0.05)
1578-
return # skip the per-column Line2D loop
15791565

1580-
if self._is_ts_plot():
1566+
if is_ts:
15811567
data = maybe_convert_index(self._get_ax(0), self.data)
1582-
1583-
x = data.index # dummy, not used
1568+
x_vals = data.index # (ignored by _ts_plot)
15841569
plotf = self._ts_plot
1585-
it = data.items()
1570+
iterator: Iterable = data.items()
1571+
else:
1572+
x_vals = self._get_xticks()
1573+
plotf = self._plot
1574+
iterator = self._iter_data(self.data)
1575+
1576+
# drawing step
1577+
if use_lc:
1578+
self._draw_with_linecollection(x_vals, iterator)
15861579
else:
1587-
x = self._get_xticks()
1588-
# error: Incompatible types in assignment (expression has type
1589-
# "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
1590-
# type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
1591-
plotf = self._plot # type: ignore[assignment]
1592-
# error: Incompatible types in assignment (expression has type
1593-
# "Iterator[tuple[Hashable, ndarray[Any, Any]]]", variable has
1594-
# type "Iterable[tuple[Hashable, Series]]")
1595-
it = self._iter_data(data=self.data) # type: ignore[assignment]
1580+
self._draw_iteratively(x_vals, plotf, iterator, is_ts=is_ts)
1581+
1582+
self._post_plot_logic(self._get_ax(0), self.data)
1583+
1584+
# fast path: one LineCollection
1585+
def _draw_with_linecollection(self, x_vals, iterator):
1586+
ax: Axes = self._get_ax(0)
1587+
n_rows = len(x_vals)
1588+
n_cols = len(self.data.columns)
1589+
1590+
# vertices: vectorised (n_cols, n_rows, 2)
1591+
x_vec = np.asarray(x_vals, dtype=float)
1592+
x2d = np.broadcast_to(x_vec[:, None], (n_rows, n_cols))
1593+
seg_array = np.dstack((x2d, self.data.values)).transpose(1, 0, 2)
1594+
1595+
# colours
1596+
base = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
1597+
colours = list(itertools.islice(itertools.cycle(base), n_cols))
15961598

1599+
# legend handling
1600+
if self.legend:
1601+
# Build proxy Line2D handles once (keeps API parity and tests).
1602+
proxy_handles = [
1603+
mpl.lines.Line2D([], [], color=colours[i], label=pprint_thing(lbl))
1604+
for i, (lbl, _) in enumerate(iterator)
1605+
]
1606+
for h in proxy_handles:
1607+
self._append_legend_handles_labels(h, h.get_label())
1608+
1609+
lc = LineCollection(
1610+
seg_array,
1611+
colors=colours,
1612+
linewidths=self.kwds.get("linewidth", mpl.rcParams["lines.linewidth"]),
1613+
)
1614+
ax.add_collection(lc)
1615+
ax.margins(0.05)
1616+
1617+
def _draw_iteratively(self, x_vals, plotf, iterator, *, is_ts: bool):
15971618
stacking_id = self._get_stacking_id()
15981619
is_errorbar = com.any_not_none(*self.errors.values())
1620+
colours = self._get_colors()
15991621

1600-
colors = self._get_colors()
1601-
for i, (label, y) in enumerate(it):
1602-
ax = self._get_ax(i)
1622+
for i, (label, y) in enumerate(iterator):
1623+
ax: Axes = self._get_ax(i)
16031624
kwds = self.kwds.copy()
16041625
if self.color is not None:
16051626
kwds["color"] = self.color
1606-
style, kwds = self._apply_style_colors(
1607-
colors,
1608-
kwds,
1609-
i,
1610-
# error: Argument 4 to "_apply_style_colors" of "MPLPlot" has
1611-
# incompatible type "Hashable"; expected "str"
1612-
label, # type: ignore[arg-type]
1613-
)
16141627

1615-
errors = self._get_errorbars(label=label, index=i)
1616-
kwds = dict(kwds, **errors)
1617-
1618-
label = pprint_thing(label)
1619-
label = self._mark_right_label(label, index=i)
1620-
kwds["label"] = label
1628+
style, kwds = self._apply_style_colors(colours, kwds, i, label)
1629+
kwds.update(self._get_errorbars(label, i))
1630+
kwds["label"] = self._mark_right_label(pprint_thing(label), index=i)
16211631

16221632
newlines = plotf(
16231633
ax,
1624-
x,
1634+
x_vals,
16251635
y,
16261636
style=style,
16271637
column_num=i,
16281638
stacking_id=stacking_id,
16291639
is_errorbar=is_errorbar,
16301640
**kwds,
16311641
)
1632-
self._append_legend_handles_labels(newlines[0], label)
1642+
self._append_legend_handles_labels(newlines[0], kwds["label"])
16331643

1634-
if self._is_ts_plot():
1635-
# reset of xlim should be used for ts data
1636-
# TODO: GH28021, should find a way to change view limit on xaxis
1644+
if is_ts:
16371645
lines = get_all_lines(ax)
16381646
left, right = get_xlim(lines)
16391647
ax.set_xlim(left, right)

0 commit comments

Comments
 (0)