Skip to content

Commit be1330b

Browse files
committed
Implement Elemwise and Blockwise operations for XTensorVariables
1 parent cd1e5dc commit be1330b

File tree

9 files changed

+703
-0
lines changed

9 files changed

+703
-0
lines changed

pytensor/xtensor/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import warnings
22

33
import pytensor.xtensor.rewriting
4+
from pytensor.xtensor import (
5+
linalg,
6+
)
47
from pytensor.xtensor.type import (
58
XTensorType,
69
as_xtensor,

pytensor/xtensor/linalg.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from collections.abc import Sequence
2+
from typing import Literal
3+
4+
from pytensor.tensor.slinalg import Cholesky, Solve
5+
from pytensor.xtensor.type import as_xtensor
6+
from pytensor.xtensor.vectorization import XBlockwise
7+
8+
9+
def cholesky(
    x,
    lower: bool = True,
    *,
    check_finite: bool = False,
    overwrite_a: bool = False,
    on_error: Literal["raise", "nan"] = "raise",
    dims: Sequence[str],
):
    """Apply a Cholesky core op over two named dimensions of ``x``.

    Parameters
    ----------
    x
        Input passed to the resulting ``XBlockwise`` op.
    lower, check_finite, overwrite_a, on_error
        Forwarded unchanged to the ``Cholesky`` core op.
    dims
        Exactly two distinct dimension names. They are used, in the given
        order, as the core dims of both the input and the output.

    Returns
    -------
    The result of calling ``XBlockwise(Cholesky(...), core_dims=...)`` on ``x``.

    Raises
    ------
    ValueError
        If ``dims`` does not have length 2, or names the same dimension twice.
    """
    if len(dims) != 2:
        raise ValueError(f"Cholesky needs two dims, got {len(dims)}")

    first_dim, second_dim = dims
    # A repeated label cannot describe the two axes of a matrix; fail early
    # with a clear message instead of building core dims with a duplicate.
    if first_dim == second_dim:
        raise ValueError(
            f"Cholesky needs two distinct dims, got {first_dim!r} twice"
        )

    core_op = Cholesky(
        lower=lower,
        check_finite=check_finite,
        overwrite_a=overwrite_a,
        on_error=on_error,
    )
    # Input and output share the same pair of core dims.
    core_dims = (
        ((first_dim, second_dim),),
        ((first_dim, second_dim),),
    )
    x_op = XBlockwise(core_op, core_dims=core_dims)
    return x_op(x)
33+
34+
35+
def solve(
    a,
    b,
    dims: Sequence[str],
    assume_a="gen",
    lower: bool = False,
    check_finite: bool = False,
):
    """Apply a ``Solve`` core op to ``a`` and ``b`` over named core dimensions.

    Parameters
    ----------
    a, b
        Inputs; both are converted with ``as_xtensor``.
    dims
        Core dimension names, of length 2 or 3.

        * Length 2 (``b_ndim=1``): exactly one of the two dims must be absent
          from ``b``. ``a`` uses ``(absent_dim, other_dim)`` as core dims,
          ``b`` uses ``(other_dim,)`` and the output uses ``(absent_dim,)``.
        * Length 3 (``b_ndim=2``): exactly one of the three dims must be absent
          from ``a`` (call it ``n``). The remaining two, in the order given,
          are the core dims ``(m1, m2)`` of ``a``; ``b`` uses ``(m2, n)`` and
          the output uses ``(m1, n)``.
    assume_a, lower, check_finite
        Forwarded unchanged to the ``Solve`` core op.

    Returns
    -------
    The result of calling the resulting ``XBlockwise`` op on ``(a, b)``.

    Raises
    ------
    ValueError
        If ``dims`` does not have length 2 or 3, or if the dims cannot be
        split between ``a`` and ``b`` as described above.
    """
    a, b = as_xtensor(a), as_xtensor(b)
    input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]]
    output_core_dims: tuple[tuple[str] | tuple[str, str]]
    if len(dims) == 2:
        b_ndim = 1
        # The dim that b lacks is the one that survives in the output.
        dims_missing_from_b = [dim for dim in dims if dim not in b.type.dims]
        if len(dims_missing_from_b) != 1:
            raise ValueError(
                "Solve with two dims needs exactly one of them to be absent from b, "
                f"got dims {tuple(dims)} and b dims {tuple(b.type.dims)}"
            )
        [m1_dim] = dims_missing_from_b
        m2_dim = dims[0] if dims[0] != m1_dim else dims[1]
        input_core_dims = ((m1_dim, m2_dim), (m2_dim,))
        # The shared dim disappears in the output
        output_core_dims = ((m1_dim,),)
    elif len(dims) == 3:
        b_ndim = 2
        # The dim that a lacks belongs only to b and is kept in the output.
        dims_missing_from_a = [dim for dim in dims if dim not in a.type.dims]
        if len(dims_missing_from_a) != 1:
            raise ValueError(
                "Solve with three dims needs exactly one of them to be absent from a, "
                f"got dims {tuple(dims)} and a dims {tuple(a.type.dims)}"
            )
        [n_dim] = dims_missing_from_a
        [m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim]
        input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim))
        # The shared dim disappears in the output
        output_core_dims = ((m1_dim, n_dim),)
    else:
        raise ValueError("Solve dims must have length 2 or 3")

    core_op = Solve(
        b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite
    )
    x_op = XBlockwise(
        core_op,
        core_dims=(input_core_dims, output_core_dims),
    )
    return x_op(a, b)

pytensor/xtensor/math.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import sys
2+
3+
import pytensor.scalar as ps
4+
from pytensor.scalar import ScalarOp
5+
from pytensor.xtensor.vectorization import XElemwise
6+
7+
8+
# The module object for this file, looked up in sys.modules by its own name.
# NOTE(review): not referenced in the visible part of the file; presumably kept
# for attaching names to the module dynamically. Confirm before removing.
this_module = sys.modules[__name__]
9+
10+
11+
def _as_xelemwise(core_op: ScalarOp) -> XElemwise:
    """Wrap the scalar op ``core_op`` in an ``XElemwise`` and give it a docstring."""
    x_op = XElemwise(core_op)
    # The docstring is set per instance, so each wrapped op names its own scalar op.
    x_op.__doc__ = f"Ufunc version of {core_op} for XTensorVariables"
    return x_op
15+
16+
17+
# Module-level XElemwise wrappers, one per scalar op.
# Where an op is exposed under several names, the first name holds the wrapper
# and the following names are plain aliases of that same object.
# Several names (abs, complex, pow, round) shadow Python builtins in this module.
abs = _as_xelemwise(ps.abs)
add = _as_xelemwise(ps.add)
logical_and = _as_xelemwise(ps.and_)
bitwise_and = logical_and
and_ = logical_and
angle = _as_xelemwise(ps.angle)
arccos = _as_xelemwise(ps.arccos)
arccosh = _as_xelemwise(ps.arccosh)
arcsin = _as_xelemwise(ps.arcsin)
arcsinh = _as_xelemwise(ps.arcsinh)
arctan = _as_xelemwise(ps.arctan)
arctan2 = _as_xelemwise(ps.arctan2)
arctanh = _as_xelemwise(ps.arctanh)
betainc = _as_xelemwise(ps.betainc)
betaincinv = _as_xelemwise(ps.betaincinv)
ceil = _as_xelemwise(ps.ceil)
clip = _as_xelemwise(ps.clip)
complex = _as_xelemwise(ps.complex)
conjugate = _as_xelemwise(ps.conj)
conj = conjugate
cos = _as_xelemwise(ps.cos)
cosh = _as_xelemwise(ps.cosh)
deg2rad = _as_xelemwise(ps.deg2rad)
equal = _as_xelemwise(ps.eq)
eq = equal
erf = _as_xelemwise(ps.erf)
erfc = _as_xelemwise(ps.erfc)
erfcinv = _as_xelemwise(ps.erfcinv)
erfcx = _as_xelemwise(ps.erfcx)
erfinv = _as_xelemwise(ps.erfinv)
exp = _as_xelemwise(ps.exp)
exp2 = _as_xelemwise(ps.exp2)
expm1 = _as_xelemwise(ps.expm1)
floor = _as_xelemwise(ps.floor)
floor_divide = _as_xelemwise(ps.int_div)
floor_div = floor_divide
int_div = floor_divide
gamma = _as_xelemwise(ps.gamma)
gammainc = _as_xelemwise(ps.gammainc)
gammaincc = _as_xelemwise(ps.gammaincc)
gammainccinv = _as_xelemwise(ps.gammainccinv)
gammaincinv = _as_xelemwise(ps.gammaincinv)
gammal = _as_xelemwise(ps.gammal)
gammaln = _as_xelemwise(ps.gammaln)
gammau = _as_xelemwise(ps.gammau)
greater_equal = _as_xelemwise(ps.ge)
ge = greater_equal
greater = _as_xelemwise(ps.gt)
gt = greater
hyp2f1 = _as_xelemwise(ps.hyp2f1)
i0 = _as_xelemwise(ps.i0)
i1 = _as_xelemwise(ps.i1)
identity = _as_xelemwise(ps.identity)
imag = _as_xelemwise(ps.imag)
logical_not = _as_xelemwise(ps.invert)
bitwise_invert = logical_not
bitwise_not = logical_not
invert = logical_not
isinf = _as_xelemwise(ps.isinf)
isnan = _as_xelemwise(ps.isnan)
iv = _as_xelemwise(ps.iv)
ive = _as_xelemwise(ps.ive)
j0 = _as_xelemwise(ps.j0)
j1 = _as_xelemwise(ps.j1)
jv = _as_xelemwise(ps.jv)
kve = _as_xelemwise(ps.kve)
less_equal = _as_xelemwise(ps.le)
le = less_equal
log = _as_xelemwise(ps.log)
log10 = _as_xelemwise(ps.log10)
log1mexp = _as_xelemwise(ps.log1mexp)
log1p = _as_xelemwise(ps.log1p)
log2 = _as_xelemwise(ps.log2)
less = _as_xelemwise(ps.lt)
lt = less
mod = _as_xelemwise(ps.mod)
multiply = _as_xelemwise(ps.mul)
mul = multiply
negative = _as_xelemwise(ps.neg)
neg = negative
not_equal = _as_xelemwise(ps.neq)
neq = not_equal
logical_or = _as_xelemwise(ps.or_)
bitwise_or = logical_or
or_ = logical_or
owens_t = _as_xelemwise(ps.owens_t)
polygamma = _as_xelemwise(ps.polygamma)
power = _as_xelemwise(ps.pow)
pow = power
psi = _as_xelemwise(ps.psi)
rad2deg = _as_xelemwise(ps.rad2deg)
real = _as_xelemwise(ps.real)
reciprocal = _as_xelemwise(ps.reciprocal)
# The next three names differ from the scalar op they wrap.
round = _as_xelemwise(ps.round_half_to_even)
maximum = _as_xelemwise(ps.scalar_maximum)
minimum = _as_xelemwise(ps.scalar_minimum)
second = _as_xelemwise(ps.second)
sigmoid = _as_xelemwise(ps.sigmoid)
sign = _as_xelemwise(ps.sign)
sin = _as_xelemwise(ps.sin)
sinh = _as_xelemwise(ps.sinh)
softplus = _as_xelemwise(ps.softplus)
square = _as_xelemwise(ps.sqr)
sqr = square
sqrt = _as_xelemwise(ps.sqrt)
subtract = _as_xelemwise(ps.sub)
sub = subtract
where = _as_xelemwise(ps.switch)
switch = where
tan = _as_xelemwise(ps.tan)
tanh = _as_xelemwise(ps.tanh)
tri_gamma = _as_xelemwise(ps.tri_gamma)
true_divide = _as_xelemwise(ps.true_div)
true_div = true_divide
trunc = _as_xelemwise(ps.trunc)
logical_xor = _as_xelemwise(ps.xor)
bitwise_xor = logical_xor
xor = logical_xor
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
import pytensor.xtensor.rewriting.basic
22
import pytensor.xtensor.rewriting.shape
3+
import pytensor.xtensor.rewriting.vectorization
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from pytensor.graph import node_rewriter
2+
from pytensor.tensor.blockwise import Blockwise
3+
from pytensor.tensor.elemwise import Elemwise
4+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
5+
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
6+
from pytensor.xtensor.vectorization import XBlockwise, XElemwise
7+
8+
9+
@register_lower_xtensor
@node_rewriter(tracks=[XElemwise])
def lower_elemwise(fgraph, node):
    """Rewrite an ``XElemwise`` node as a tensor ``Elemwise`` on aligned inputs.

    Each input is converted with ``tensor_from_xtensor`` and dimshuffled to the
    dim order of the node's first output. The ``Elemwise`` built from the node's
    ``scalar_op`` is applied, and every result is wrapped back with
    ``xtensor_from_tensor`` using those same output dims.
    """
    target_dims = node.outputs[0].type.dims

    def _aligned_tensor(x_inp):
        # For every output dim, take its position in this input, or the
        # dimshuffle placeholder "x" when the input does not have that dim.
        source_dims = x_inp.type.dims
        pattern = []
        for dim in target_dims:
            if dim in source_dims:
                pattern.append(source_dims.index(dim))
            else:
                pattern.append("x")
        return tensor_from_xtensor(x_inp).dimshuffle(pattern)

    aligned_inputs = [_aligned_tensor(x_inp) for x_inp in node.inputs]

    elemwise_op = Elemwise(scalar_op=node.op.scalar_op)
    raw_outs = elemwise_op(*aligned_inputs, return_list=True)

    # Re-attach the output dims to every tensor result
    return [xtensor_from_tensor(raw_out, dims=target_dims) for raw_out in raw_outs]
34+
35+
36+
@register_lower_xtensor
@node_rewriter(tracks=[XBlockwise])
def lower_blockwise(fgraph, node):
    """Rewrite an ``XBlockwise`` node as a tensor ``Blockwise`` on aligned inputs.

    The batch dims are the leading dims of the first output, i.e. everything
    before that output's core dims. Each input is converted to a tensor and
    dimshuffled so that batch dims come first (in output order) and its own
    core dims come last. The ``Blockwise`` results are wrapped back into
    xtensors with the dims of the corresponding original outputs.
    """
    op: XBlockwise = node.op
    batch_ndim = node.outputs[0].type.ndim - len(op.core_dims[1][0])
    batch_dims = node.outputs[0].type.dims[:batch_ndim]

    # Convert input XTensors to Tensors, align batch dimensions and place core dimensions at the end
    tensor_inputs = []
    # strict=True: an input without matching core dims (or the reverse) must
    # fail loudly rather than be dropped, matching the zip over the outputs below.
    for inp, core_dims in zip(node.inputs, op.core_dims[0], strict=True):
        inp_dims = inp.type.dims
        # Align the batch dims of the input, and place the core dims on the right
        batch_order = [
            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
            for batch_dim in batch_dims
        ]
        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
        tensor_inputs.append(tensor_inp)

    # Prefer an explicit signature on the op, then one declared by the core op
    signature = op.signature or getattr(op.core_op, "gufunc_signature", None)
    if signature is None:
        # Build a signature based on the core dimensions
        # The Op signature could be more strict, as core_dims will never be repeated, but no functionality depends greatly on it
        inputs_core_dims, outputs_core_dims = op.core_dims
        inputs_signature = ",".join(
            f"({', '.join(inp_core_dims)})" for inp_core_dims in inputs_core_dims
        )
        outputs_signature = ",".join(
            f"({', '.join(out_core_dims)})" for out_core_dims in outputs_core_dims
        )
        signature = f"{inputs_signature}->{outputs_signature}"
    tensor_op = Blockwise(core_op=op.core_op, signature=signature)
    tensor_outs = tensor_op(*tensor_inputs, return_list=True)

    # Convert output Tensors to XTensors
    new_outs = [
        xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
        for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True)
    ]
    return new_outs

0 commit comments

Comments
 (0)