Skip to content

Commit 89f9062

Browse files
authored
enh: Add __repr__ to reductions (#1418)
1 parent f8e83b9 commit 89f9062

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

datashader/reductions.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,9 @@ def _create_int64(shape, array_module):
471471
def _create_uint32(shape, array_module):
472472
return array_module.zeros(shape, dtype='u4')
473473

474+
def __repr__(self):
475+
return f"{type(self).__name__}({self.column!r})"
476+
474477

475478
class OptionalFieldReduction(Reduction):
476479
"""Base class for things like ``count`` or ``any`` for which the field is optional"""
@@ -657,6 +660,9 @@ def _combine_antialias(aggs):
657660
nansum_in_place(ret, aggs[i])
658661
return ret
659662

663+
def __repr__(self):
664+
return "count()"
665+
660666

661667
class _count_ignore_antialiasing(count):
662668
"""Count reduction but ignores antialiasing. Used by mean reduction.
@@ -813,6 +819,9 @@ def finalize(bases, cuda=False, **kwargs):
813819

814820
return finalize
815821

822+
def __repr__(self):
823+
return f"{type(self).__name__}(column={self.column!r}, reduction={self.reduction!r})"
824+
816825
class any(OptionalFieldReduction):
817826
"""Whether any elements in ``column`` map to each bin.
818827
@@ -1265,6 +1274,8 @@ class count_cat(by):
12651274
def __init__(self, column):
12661275
super().__init__(column, count())
12671276

1277+
def __repr__(self):
1278+
return f"count_cat(column={self.column!r})"
12681279

12691280
class mean(Reduction):
12701281
"""Mean of all elements in ``column``.
@@ -1477,6 +1488,9 @@ def finalize(bases, cuda=False, **kwargs):
14771488
def _hashable_inputs(self):
14781489
return super()._hashable_inputs() + (self.n,)
14791490

1491+
def __repr__(self):
1492+
return f"{type(self).__name__}(column={self.column!r}, n={self.n!r})"
1493+
14801494

14811495
class _first_n_or_last_n(FloatingNReduction):
14821496
"""Abstract base class of first_n and last_n reductions.
@@ -2103,6 +2117,9 @@ def finalize(bases, cuda=False, **kwargs):
21032117

21042118
return finalize
21052119

2120+
def __repr__(self):
2121+
return f"where(selector={self.selector!r}, lookup_column={self.column!r})"
2122+
21062123

21072124
class summary(Expr):
21082125
"""A collection of named reductions.
@@ -2166,6 +2183,10 @@ def validate(self, input_dshape):
21662183
def inputs(self):
21672184
return tuple(unique(concat(v.inputs for v in self.values)))
21682185

2186+
def __repr__(self):
2187+
pairs = ", ".join([f"{k}={v!r}" for k, v in zip(self.keys, self.values, strict=True)])
2188+
return f"summary({pairs})"
2189+
21692190

21702191
class _max_or_min_row_index(OptionalFieldReduction):
21712192
"""Abstract base class of max and min row_index reductions.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import datashader as ds
2+
3+
4+
def all_subclasses(cls):
5+
items1 = {cls, *cls.__subclasses__()}
6+
items2 = {s for c in cls.__subclasses__() for s in all_subclasses(c)}
7+
return items1 | items2
8+
9+
10+
def test_string_output():
11+
expected = {
12+
"any": "any('col')",
13+
"by": "by(column='col', reduction=count())",
14+
"count": "count()",
15+
"count_cat": "count_cat(column='col')",
16+
"first": "first('col')",
17+
"first_n": "first_n(column='col', n=1)",
18+
"last": "last('col')",
19+
"last_n": "last_n(column='col', n=1)",
20+
"m2": "m2('col')",
21+
"max": "max('col')",
22+
"max_n": "max_n(column='col', n=1)",
23+
"mean": "mean('col')",
24+
"min": "min('col')",
25+
"min_n": "min_n(column='col', n=1)",
26+
"mode": "mode('col')",
27+
"std": "std('col')",
28+
"sum": "sum('col')",
29+
"summary": "summary(a=1)",
30+
"var": "var('col')",
31+
"where": "where(selector=min('col'), lookup_column='col')",
32+
}
33+
34+
count = 0
35+
for red in all_subclasses(ds.reductions.Reduction) | all_subclasses(ds.reductions.summary):
36+
red_name = red.__name__
37+
if red_name.startswith("_") or "Reduction" in red_name:
38+
continue
39+
elif red_name in ("by", "count_cat"):
40+
assert str(red("col")) == expected[red_name]
41+
elif red_name == "where":
42+
assert str(red(ds.min("col"), "col")) == expected[red_name]
43+
elif red_name == "summary":
44+
assert str(red(a=1)) == expected[red_name]
45+
else:
46+
assert str(red("col")) == expected[red_name]
47+
count += 1
48+
49+
assert count == 20 # Update if more subclasses are added

0 commit comments

Comments
 (0)