Skip to content

Commit 8264a71

Browse files
committed
Add ability to specify which vars to aggregate
1 parent 3ef2a0f commit 8264a71

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

seaborn/_core/plot.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,19 +1185,20 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
11851185

11861186
def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:
11871187

1188-
grouping_vars = [v for v in PROPERTIES if v not in "xy"]
1189-
grouping_vars += ["col", "row", "group"]
11901188

11911189
pair_vars = spec._pair_spec.get("structure", {})
11921190

11931191
for layer in layers:
1194-
11951192
data = layer["data"]
11961193
mark = layer["mark"]
11971194
stat = layer["stat"]
11981195

11991196
if stat is None:
12001197
continue
1198+
target_vars = getattr(stat, "target_vars", "xy")
1199+
1200+
grouping_vars = [v for v in PROPERTIES if v not in target_vars]
1201+
grouping_vars += ["col", "row", "group"]
12011202

12021203
iter_axes = itertools.product(*[
12031204
pair_vars.get(axis, [axis]) for axis in "xy"

seaborn/_stats/aggregation.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22
from dataclasses import dataclass
3-
from typing import ClassVar, Callable
3+
from typing import ClassVar, Callable, Iterable
44

55
import pandas as pd
66
from pandas import DataFrame
@@ -21,6 +21,9 @@ class Agg(Stat):
2121
----------
2222
func : str or callable
2323
Name of a :class:`pandas.Series` method or a vector -> scalar function.
24+
target_vars : list of strings
25+
Variables to perform the aggregation on. Defaults to x or y, depending on
26+
orientation.
2427
2528
See Also
2629
--------
@@ -32,18 +35,19 @@ class Agg(Stat):
3235
3336
"""
3437
func: str | Callable[[Vector], float] = "mean"
38+
target_vars: Iterable[str] = ("x", "y")
3539

3640
group_by_orient: ClassVar[bool] = True
3741

3842
def __call__(
3943
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
4044
) -> DataFrame:
4145

42-
var = {"x": "y", "y": "x"}.get(orient)
46+
vars = [v for v in self.target_vars if v != orient]
4347
res = (
4448
groupby
45-
.agg(data, {var: self.func})
46-
.dropna(subset=[var])
49+
.agg(data, {var: self.func for var in vars})
50+
.dropna(subset=vars)
4751
.reset_index(drop=True)
4852
)
4953
return res

0 commit comments

Comments
 (0)