
Commit 010e0f9 (parent: cdb026f)

Implement reduction operations for XTensorVariables

File tree (8 files changed: +266 -21 lines)

  pytensor/tensor/extra_ops.py
  pytensor/xtensor/__init__.py
  pytensor/xtensor/math.py
  pytensor/xtensor/reduction.py
  pytensor/xtensor/rewriting/__init__.py
  pytensor/xtensor/rewriting/reduction.py
  pytensor/xtensor/type.py
  tests/xtensor/test_reduction.py

pytensor/tensor/extra_ops.py
Lines changed: 0 additions & 18 deletions

@@ -473,24 +473,6 @@ def cumprod(x, axis=None):
     return CumOp(axis=axis, mode="mul")(x)
 
 
-class CumsumOp(Op):
-    __props__ = ("axis",)
-
-    def __new__(typ, *args, **kwargs):
-        obj = object.__new__(CumOp, *args, **kwargs)
-        obj.mode = "add"
-        return obj
-
-
-class CumprodOp(Op):
-    __props__ = ("axis",)
-
-    def __new__(typ, *args, **kwargs):
-        obj = object.__new__(CumOp, *args, **kwargs)
-        obj.mode = "mul"
-        return obj
-
-
 def diff(x, n=1, axis=-1):
     """Calculate the `n`-th order discrete difference along the given `axis`.
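
The two deleted classes were `__new__` shims that constructed a `CumOp` directly, presumably kept for backward compatibility. The public helpers are unaffected; a minimal sketch of the surviving API (the variable `x` is illustrative):

    import pytensor.tensor as pt
    from pytensor.tensor.extra_ops import cumprod, cumsum

    x = pt.vector("x")
    cs = cumsum(x)   # builds CumOp(axis=None, mode="add")
    cp = cumprod(x)  # builds CumOp(axis=None, mode="mul")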

pytensor/xtensor/__init__.py
Lines changed: 1 addition & 3 deletions

@@ -1,9 +1,7 @@
 import warnings
 
 import pytensor.xtensor.rewriting
-from pytensor.xtensor import (
-    linalg,
-)
+from pytensor.xtensor import linalg
 from pytensor.xtensor.type import (
     XTensorType,
     as_xtensor,

pytensor/xtensor/math.py
Lines changed: 5 additions & 0 deletions

@@ -134,3 +134,8 @@ def cast(x, dtype):
     if dtype not in _xelemwise_cast_op:
         _xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
     return _xelemwise_cast_op[dtype](x)
+
+
+def softmax(x, dim=None):
+    exp_x = exp(x)
+    return exp_x / exp_x.sum(dim=dim)
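
The new `softmax` follows the textbook definition (note: no max-subtraction for numerical stability) and composes directly with the reductions added in this commit. A minimal usage sketch; `logits` and its dims are illustrative:

    from pytensor.xtensor.math import softmax
    from pytensor.xtensor.type import xtensor

    logits = xtensor("logits", dims=("batch", "label"), shape=(4, 10))
    probs = softmax(logits, dim="label")  # normalized so probs sum to 1 over "label"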

pytensor/xtensor/reduction.py (new file)
Lines changed: 125 additions & 0 deletions

@@ -0,0 +1,125 @@
+import typing
+from collections.abc import Sequence
+from functools import partial
+from types import EllipsisType
+
+import pytensor.scalar as ps
+from pytensor.graph.basic import Apply
+from pytensor.tensor.math import variadic_mul
+from pytensor.xtensor.basic import XOp
+from pytensor.xtensor.math import neq, sqrt
+from pytensor.xtensor.math import sqr as square
+from pytensor.xtensor.type import as_xtensor, xtensor
+
+
+REDUCE_DIM = str | Sequence[str] | EllipsisType | None
+
+
+class XReduce(XOp):
+    __slots__ = ("binary_op", "dims")
+
+    def __init__(self, binary_op, dims: Sequence[str]):
+        super().__init__()
+        self.binary_op = binary_op
+        # Order of reduce dims doesn't change the behavior of the Op
+        self.dims = tuple(sorted(dims))
+
+    def make_node(self, x):
+        x = as_xtensor(x)
+        x_dims = x.type.dims
+        x_dims_set = set(x_dims)
+        reduce_dims_set = set(self.dims)
+        if x_dims_set == reduce_dims_set:
+            out_dims, out_shape = [], []
+        else:
+            if not reduce_dims_set.issubset(x_dims_set):
+                raise ValueError(
+                    f"Reduced dims {self.dims} not found in array dimensions {x_dims}."
+                )
+            out_dims, out_shape = zip(
+                *[
+                    (d, s)
+                    for d, s in zip(x_dims, x.type.shape)
+                    if d not in reduce_dims_set
+                ]
+            )
+        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
+        return Apply(self, [x], [output])
+
+
+def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]:
+    if isinstance(dim, str):
+        return (dim,)
+    elif dim is None or dim is Ellipsis:
+        x = as_xtensor(x)
+        return typing.cast(tuple[str], x.type.dims)
+    return dim
+
+
+def reduce(x, dim: REDUCE_DIM = None, *, binary_op):
+    dims = _process_user_dims(x, dim)
+    return XReduce(binary_op=binary_op, dims=dims)(x)
+
+
+sum = partial(reduce, binary_op=ps.add)
+prod = partial(reduce, binary_op=ps.mul)
+max = partial(reduce, binary_op=ps.scalar_maximum)
+min = partial(reduce, binary_op=ps.scalar_minimum)
+
+
+def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op):
+    x = as_xtensor(x)
+    if x.type.dtype != "bool":
+        x = neq(x, 0)
+    return reduce(x, dim=dim, binary_op=binary_op)
+
+
+all = partial(bool_reduce, binary_op=ps.and_)
+any = partial(bool_reduce, binary_op=ps.or_)
+
+
+def _infer_reduced_size(original_var, reduced_var):
+    reduced_dims = reduced_var.dims
+    return variadic_mul(
+        *[size for dim, size in original_var.sizes if dim not in reduced_dims]
+    )
+
+
+def mean(x, dim: REDUCE_DIM):
+    x = as_xtensor(x)
+    sum_x = sum(x, dim)
+    n = _infer_reduced_size(x, sum_x)
+    return sum_x / n
+
+
+def var(x, dim: REDUCE_DIM, *, ddof: int = 0):
+    x = as_xtensor(x)
+    x_mean = mean(x, dim)
+    n = _infer_reduced_size(x, x_mean)
+    return square(x - x_mean) / (n - ddof)
+
+
+def std(x, dim: REDUCE_DIM, *, ddof: int = 0):
+    return sqrt(var(x, dim, ddof=ddof))
+
+
+class XCumReduce(XOp):
+    __props__ = ("binary_op", "dims")
+
+    def __init__(self, binary_op, dims: Sequence[str]):
+        self.binary_op = binary_op
+        self.dims = tuple(sorted(dims))  # Order doesn't matter
+
+    def make_node(self, x):
+        x = as_xtensor(x)
+        out = x.type()
+        return Apply(self, [x], [out])
+
+
+def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
+    dims = _process_user_dims(x, dim)
+    return XCumReduce(dims=dims, binary_op=binary_op)(x)
+
+
+cumsum = partial(cumreduce, binary_op=ps.add)
+cumprod = partial(cumreduce, binary_op=ps.mul)
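
A usage sketch of the module above (variable names are illustrative; behavior follows `_process_user_dims` and `XReduce.make_node` as defined in the diff):

    import pytensor.xtensor.reduction as xred
    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))

    s = xred.sum(x, "a")          # drops "a" -> dims ("b", "c"), shape (5, 7)
    p = xred.prod(x, ("a", "c"))  # drops "a" and "c" -> dims ("b",)
    t = xred.any(x)               # dim=None reduces all dims; non-bool input first goes through neq(x, 0)
    c = xred.cumsum(x, "b")       # cumulative reduce keeps the full shape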

pytensor/xtensor/rewriting/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 import pytensor.xtensor.rewriting.basic
+import pytensor.xtensor.rewriting.reduction
 import pytensor.xtensor.rewriting.shape
 import pytensor.xtensor.rewriting.vectorization

pytensor/xtensor/rewriting/reduction.py (new file)
Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+from functools import partial
+
+import pytensor.scalar as ps
+from pytensor.graph.rewriting.basic import node_rewriter
+from pytensor.tensor.extra_ops import CumOp
+from pytensor.tensor.math import All, Any, CAReduce, Max, Min, Prod, Sum
+from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
+from pytensor.xtensor.reduction import XCumReduce, XReduce
+from pytensor.xtensor.rewriting.utils import register_lower_xtensor
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[XReduce])
+def lower_reduce(fgraph, node):
+    [x] = node.inputs
+    [out] = node.outputs
+    x_dims = x.type.dims
+    reduce_dims = node.op.dims
+    reduce_axis = [x_dims.index(dim) for dim in reduce_dims]
+
+    if not reduce_axis:
+        return [x]
+
+    match node.op.binary_op:
+        case ps.add:
+            tensor_op_class = Sum
+        case ps.mul:
+            tensor_op_class = Prod
+        case ps.and_:
+            tensor_op_class = All
+        case ps.or_:
+            tensor_op_class = Any
+        case ps.scalar_maximum:
+            tensor_op_class = Max
+        case ps.scalar_minimum:
+            tensor_op_class = Min
+        case _:
+            # Case without known/predefined Ops
+            tensor_op_class = partial(CAReduce, scalar_op=node.op.binary_op)
+
+    x_tensor = tensor_from_xtensor(x)
+    out_tensor = tensor_op_class(axis=reduce_axis)(x_tensor)
+    new_out = xtensor_from_tensor(out_tensor, out.type.dims)
+    return [new_out]
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[XCumReduce])
+def lower_cumreduce(fgraph, node):
+    [x] = node.inputs
+    x_dims = x.type.dims
+    reduce_dims = node.op.dims
+    reduce_axis = [x_dims.index(dim) for dim in reduce_dims]
+
+    if not reduce_axis:
+        return [x]
+
+    match node.op.binary_op:
+        case ps.add:
+            tensor_op_class = partial(CumOp, mode="add")
+        case ps.mul:
+            tensor_op_class = partial(CumOp, mode="mul")
+        case _:
+            # We don't know how to convert an arbitrary binary cum/reduce Op
+            return None
+
+    # Each dim corresponds to an application of Cumsum/Cumprod
+    out_tensor = tensor_from_xtensor(x)
+    for axis in reduce_axis:
+        out_tensor = tensor_op_class(axis=axis)(out_tensor)
+    out = xtensor_from_tensor(out_tensor, x.type.dims)
+    return [out]
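
For reference, `lower_reduce` is doing the equivalent of the following by hand, sketched here using only names that appear in the diff: map dim names to positional axes, reduce the unwrapped tensor, and re-wrap with the surviving dims.

    from pytensor.tensor.math import Sum
    from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("a", "b"), shape=(3, 5))
    axis = [x.type.dims.index(d) for d in ("a",)]  # dim names -> positional axes
    lowered = xtensor_from_tensor(
        Sum(axis=axis)(tensor_from_xtensor(x)),  # plain tensor reduction
        ("b",),                                  # dims that survive the reduction
    )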

pytensor/xtensor/type.py
Lines changed: 35 additions & 0 deletions

@@ -438,6 +438,41 @@ def imag(self):
     def real(self):
         return px.math.real(self)
 
+    # Aggregation
+    # https://docs.xarray.dev/en/latest/api.html#id6
+    def all(self, dim=None):
+        return px.reduction.all(self, dim)
+
+    def any(self, dim=None):
+        return px.reduction.any(self, dim)
+
+    def max(self, dim=None):
+        return px.reduction.max(self, dim)
+
+    def min(self, dim=None):
+        return px.reduction.min(self, dim)
+
+    def mean(self, dim=None):
+        return px.reduction.mean(self, dim)
+
+    def prod(self, dim=None):
+        return px.reduction.prod(self, dim)
+
+    def sum(self, dim=None):
+        return px.reduction.sum(self, dim)
+
+    def std(self, dim=None):
+        return px.reduction.std(self, dim)
+
+    def var(self, dim=None):
+        return px.reduction.var(self, dim)
+
+    def cumsum(self, dim=None):
+        return px.reduction.cumsum(self, dim)
+
+    def cumprod(self, dim=None):
+        return px.reduction.cumprod(self, dim)
+
     # Reshaping and reorganizing
     # https://docs.xarray.dev/en/latest/api.html#id8
     def stack(self, dim, **dims):
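
The methods mirror xarray's aggregation API, so reductions read naturally on the variable itself. A short sketch with illustrative names:

    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("time", "space"), shape=(100, 8))

    total = x.sum(dim="space")      # dims ("time",)
    running = x.cumsum(dim="time")  # cumulative; dims unchanged
    nonzero = x.any(dim=...)        # Ellipsis reduces every dim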

tests/xtensor/test_reduction.py (new file)
Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+# ruff: noqa: E402
+import pytest
+
+
+pytest.importorskip("xarray")
+
+from pytensor.xtensor.type import xtensor
+from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
+
+
+@pytest.mark.parametrize(
+    "dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"]
+)
+@pytest.mark.parametrize(
+    "method", ["sum", "prod", "all", "any", "max", "min", "cumsum", "cumprod"][2:]
+)
+def test_reduction(method, dim):
+    x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
+    out = getattr(x, method)(dim=dim)
+
+    fn = xr_function([x], out)
+    x_test = xr_arange_like(x)
+
+    xr_assert_allclose(
+        fn(x_test),
+        getattr(x_test, method)(dim=dim),
+    )
