
Commit 155db9f

Implement basic labeled tensor functionality
1 parent f7cf273 commit 155db9f

File tree

12 files changed: +932 -2 lines changed


.github/workflows/test.yml

Lines changed: 13 additions & 2 deletions
@@ -82,11 +82,12 @@ jobs:
         install-numba: [0]
         install-jax: [0]
         install-torch: [0]
+        install-xarray: [0]
         part:
-          - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
+          - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse --ignore=tests/xtensor"
           - "tests/scan"
           - "tests/sparse"
-          - "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py"
+          - "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py"
           - "tests/tensor/conv"
           - "tests/tensor/rewriting"
           - "tests/tensor/test_math.py"
@@ -115,6 +116,7 @@ jobs:
            install-numba: 0
            install-jax: 0
            install-torch: 0
+           install-xarray: 0
          - install-numba: 1
            os: "ubuntu-latest"
            python-version: "3.10"
@@ -150,6 +152,13 @@ jobs:
            fast-compile: 0
            float32: 0
            part: "tests/link/pytorch"
+         - install-xarray: 1
+           os: "ubuntu-latest"
+           python-version: "3.13"
+           numpy-version: ">=2.0"
+           fast-compile: 0
+           float32: 0
+           part: "tests/xtensor"
          - os: macos-15
            python-version: "3.13"
            numpy-version: ">=2.0"
@@ -196,6 +205,7 @@ jobs:
          if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
          if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
          if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
+         if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
          pip install pytest-sphinx
 
          pip install -e ./
@@ -212,6 +222,7 @@ jobs:
          INSTALL_NUMBA: ${{ matrix.install-numba }}
          INSTALL_JAX: ${{ matrix.install-jax }}
          INSTALL_TORCH: ${{ matrix.install-torch}}
+         INSTALL_XARRAY: ${{ matrix.install-xarray }}
          OS: ${{ matrix.os}}
 
       - name: Run tests

pytensor/compile/mode.py

Lines changed: 5 additions & 0 deletions
@@ -67,6 +67,8 @@ def register_linker(name, linker):
 if not config.cxx:
     exclude = ["cxx_only"]
 OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
+# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
+OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
 # Even if multiple merge optimizer call will be there, this shouldn't
 # impact performance.
 OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
@@ -77,6 +79,7 @@ def register_linker(name, linker):
 OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
 OPT_STABILIZE.position_cutoff = 1.5000001
 OPT_NONE.name = "OPT_NONE"
+OPT_MINIMUM.name = "OPT_MINIMUM"
 OPT_MERGE.name = "OPT_MERGE"
 OPT_FAST_RUN.name = "OPT_FAST_RUN"
 OPT_FAST_RUN_STABLE.name = "OPT_FAST_RUN_STABLE"
@@ -95,6 +98,7 @@ def register_linker(name, linker):
     None: OPT_NONE,
     "None": OPT_NONE,
     "merge": OPT_MERGE,
+    "minimum_compile": OPT_MINIMUM,
     "o4": OPT_FAST_RUN,
     "o3": OPT_O3,
     "o2": OPT_O2,
@@ -191,6 +195,7 @@ def apply(self, fgraph):
     "merge1", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=0
 )
 
+
 # After scan1 opt at 0.5 and before ShapeOpt at 1
 # This should only remove nodes.
 # The opt should not do anything that need shape inference.
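For illustration, a minimal sketch (not taken from this commit) of what the new `"minimum_compile"` entry enables: a `Mode` that runs only the rewrites tagged `"minimum_compile"`, which should be enough to lower an xtensor graph built with the API from the readme below. It assumes the `py` linker can evaluate the lowered graph.

```python
import pytensor
import pytensor.tensor as pt
import pytensor.xtensor as px
from pytensor.compile.mode import Mode

a = pt.tensor("a", shape=(3,))
b = pt.tensor("b", shape=(4,))
# Build an xtensor graph (see pytensor/xtensor/readme.md below) and convert back to a tensor
z = (px.as_xtensor(a, dims=["x"]) + px.as_xtensor(b, dims=["y"])).values

# Only the "minimum_compile" rewrites run; the xtensor lowering rewrites are tagged
# with "minimum_compile", so the XOps should still be removed before execution.
fn = pytensor.function([a, b], z, mode=Mode(linker="py", optimizer="minimum_compile"))
```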

pytensor/xtensor/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+import warnings
+
+import pytensor.xtensor.rewriting
+from pytensor.xtensor.type import (
+    XTensorType,
+    as_xtensor,
+    xtensor,
+    xtensor_constant,
+)
+
+
+warnings.warn("xtensor module is experimental and full of bugs")
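Because the module warns on import, downstream code that wants a quiet import can filter the warning; a minimal sketch:

```python
import warnings

with warnings.catch_warnings():
    # Silence the warning emitted when pytensor.xtensor is first imported
    warnings.simplefilter("ignore")
    import pytensor.xtensor as px
```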

pytensor/xtensor/basic.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+from collections.abc import Sequence
+
+from pytensor.compile.ops import TypeCastingOp
+from pytensor.graph import Apply, Op
+from pytensor.tensor.type import TensorType
+from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
+
+
+class XOp(Op):
+    """A base class for XOps that shouldn't be materialized"""
+
+    def perform(self, node, inputs, outputs):
+        raise NotImplementedError(
+            f"xtensor operation {self} must be lowered to equivalent tensor operations"
+        )
+
+
+class XTypeCastOp(TypeCastingOp):
+    """Base class for Ops that type cast between TensorType and XTensorType.
+
+    This is like a `ViewOp` but without the expectation the input and output have identical types.
+    """
+
+
+class TensorFromXTensor(XTypeCastOp):
+    __props__ = ()
+
+    def make_node(self, x):
+        if not isinstance(x.type, XTensorType):
+            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
+        output = TensorType(x.type.dtype, shape=x.type.shape)()
+        return Apply(self, [x], [output])
+
+    def L_op(self, inputs, outs, g_outs):
+        [x] = inputs
+        [g_out] = g_outs
+        return [xtensor_from_tensor(g_out, dims=x.type.dims)]
+
+
+tensor_from_xtensor = TensorFromXTensor()
+
+
+class XTensorFromTensor(XTypeCastOp):
+    __props__ = ("dims",)
+
+    def __init__(self, dims: Sequence[str]):
+        super().__init__()
+        self.dims = tuple(dims)
+
+    def make_node(self, x):
+        if not isinstance(x.type, TensorType):
+            raise TypeError(f"x must be a TensorType, got {type(x.type)}")
+        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
+        return Apply(self, [x], [output])
+
+    def L_op(self, inputs, outs, g_outs):
+        [g_out] = g_outs
+        return [tensor_from_xtensor(g_out)]
+
+
+def xtensor_from_tensor(x, dims, name=None):
+    return XTensorFromTensor(dims=dims)(x, name=name)
+
+
+class Rename(XTypeCastOp):
+    __props__ = ("new_dims",)
+
+    def __init__(self, new_dims: tuple[str, ...]):
+        super().__init__()
+        self.new_dims = new_dims
+
+    def make_node(self, x):
+        x = as_xtensor(x)
+        output = x.type.clone(dims=self.new_dims)()
+        return Apply(self, [x], [output])
+
+    def L_op(self, inputs, outs, g_outs):
+        [x] = inputs
+        [g_out] = g_outs
+        return [rename(g_out, dims=x.type.dims)]
+
+
+def rename(x, name_dict: dict[str, str] | None = None, **names: str):
+    if name_dict is not None:
+        if names:
+            raise ValueError("Cannot use both positional and keyword names in rename")
+        names = name_dict
+
+    x = as_xtensor(x)
+    old_names = x.type.dims
+    new_names = list(old_names)
+    for old_name, new_name in names.items():
+        try:
+            new_names[old_names.index(old_name)] = new_name
+        except ValueError:
+            raise ValueError(
+                f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
+            )
+
+    return Rename(tuple(new_names))(x)
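For illustration, a minimal sketch (not taken from this commit) of how `rename` is used, assuming the `as_xtensor` helper shown in the readme below:

```python
import pytensor.tensor as pt
import pytensor.xtensor as px
from pytensor.xtensor.basic import rename

a = pt.tensor("a", shape=(2, 3))
ax = px.as_xtensor(a, dims=["row", "col"])

bx = rename(ax, row="obs")            # keyword form
cx = rename(ax, {"col": "feature"})   # dict form
assert bx.type.dims == ("obs", "col")
assert cx.type.dims == ("row", "feature")
```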

pytensor/xtensor/readme.md

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+# XTensor Module
+
+This module implements an abstraction layer on top of regular tensor operations that behaves like xarray.
+
+A new type, `XTensorType`, generalizes `TensorType` with the addition of a `dims` attribute
+that labels the dimensions of the tensor.
+
+Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects.
+
+The module implements several PyTensor operations, `XOp`s, whose signatures mimic those of xarray (and xarray-einstats) DataArray operations.
+Unlike most regular PyTensor operations, these cannot be evaluated directly; they require a rewrite (lowering) into
+a regular tensor graph that can itself be evaluated as usual.
+
+As with regular PyTensor, we don't need an Op for every method or function in the public API of xarray.
+If the existing XOps can be composed to produce the desired result, we simply compose them.
+
+## Coordinates
+For now, there is no analogue of xarray coordinates, so you won't be able to do coordinate operations like `.sel`.
+The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor.
+Coords involve aspects of pandas/database-style querying and joining that are not trivially expressible in PyTensor.
+
+## Example
+
+```python
+import pytensor.tensor as pt
+import pytensor.xtensor as px
+
+a = pt.tensor("a", shape=(3,))
+b = pt.tensor("b", shape=(4,))
+
+ax = px.as_xtensor(a, dims=["x"])
+bx = px.as_xtensor(b, dims=["y"])
+
+zx = ax + bx
+assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))
+
+z = zx.values
+z.dprint()
+# TensorFromXTensor [id A]
+#  └─ XElemwise{scalar_op=Add()} [id B]
+#     ├─ XTensorFromTensor{dims=('x',)} [id C]
+#     │  └─ a [id D]
+#     └─ XTensorFromTensor{dims=('y',)} [id E]
+#        └─ b [id F]
+```
+
+Once we compile the graph, no `XOp`s are left.
+
+```python
+import pytensor
+
+with pytensor.config.change_flags(optimizer_verbose=True):
+    fn = pytensor.function([a, b], z)
+
+# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
+# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
+# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
+# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)
+
+fn.dprint()
+# Add [id A] 2
+#  ├─ ExpandDims{axis=1} [id B] 1
+#  │  └─ a [id C]
+#  └─ ExpandDims{axis=0} [id D] 0
+#     └─ b [id E]
+```
+
+
pytensor/xtensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+import pytensor.xtensor.rewriting.basic

pytensor/xtensor/rewriting/basic.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+from pytensor.graph import node_rewriter
+from pytensor.tensor.basic import register_infer_shape
+from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
+from pytensor.xtensor.basic import (
+    Rename,
+    TensorFromXTensor,
+    XTensorFromTensor,
+    xtensor_from_tensor,
+)
+from pytensor.xtensor.rewriting.utils import register_lower_xtensor
+
+
+@register_infer_shape
+@register_useless
+@register_canonicalize
+@register_lower_xtensor
+@node_rewriter(tracks=[TensorFromXTensor])
+def useless_tensor_from_xtensor(fgraph, node):
+    """TensorFromXTensor(XTensorFromTensor(x)) -> x"""
+    [x] = node.inputs
+    if x.owner and isinstance(x.owner.op, XTensorFromTensor):
+        return [x.owner.inputs[0]]
+
+
+@register_infer_shape
+@register_useless
+@register_canonicalize
+@register_lower_xtensor
+@node_rewriter(tracks=[XTensorFromTensor])
+def useless_xtensor_from_tensor(fgraph, node):
+    """XTensorFromTensor(TensorFromXTensor(x)) -> x"""
+    [x] = node.inputs
+    if x.owner and isinstance(x.owner.op, TensorFromXTensor):
+        return [x.owner.inputs[0]]
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[TensorFromXTensor])
+def useless_tensor_from_xtensor_of_rename(fgraph, node):
+    """TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)"""
+    [renamed_x] = node.inputs
+    if renamed_x.owner and isinstance(renamed_x.owner.op, Rename):
+        [x] = renamed_x.owner.inputs
+        return node.op(x, return_list=True)
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[Rename])
+def useless_rename(fgraph, node):
+    """
+
+    Rename(Rename(x, inner_dims), outer_dims) -> Rename(x, outer_dims)
+    Rename(XTensorFromTensor(x, inner_dims), outer_dims) -> XTensorFromTensor(x, outer_dims)
+    """
+    [renamed_x] = node.inputs
+    if renamed_x.owner:
+        if isinstance(renamed_x.owner.op, Rename):
+            [x] = renamed_x.owner.inputs
+            return [node.op(x)]
+        elif isinstance(renamed_x.owner.op, TensorFromXTensor):
+            [x] = renamed_x.owner.inputs
+            return [xtensor_from_tensor(x, dims=node.op.new_dims)]
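For illustration, a minimal sketch (not taken from this commit) of the kind of round-trip these rewrites remove, using the API from the readme above:

```python
import pytensor
import pytensor.tensor as pt
import pytensor.xtensor as px

a = pt.tensor("a", shape=(3,))
# TensorFromXTensor(XTensorFromTensor(a)): a pointless round-trip through XTensorType
round_trip = px.as_xtensor(a, dims=["x"]).values

fn = pytensor.function([a], round_trip)
fn.dprint()
# After useless_tensor_from_xtensor, the compiled graph should reduce to `a`
# (wrapped in a DeepCopyOp by the standard compilation machinery).
```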

pytensor/xtensor/rewriting/utils.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+from pytensor.compile import optdb
+from pytensor.graph.rewriting.basic import NodeRewriter
+from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
+
+
+lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
+
+optdb.register(
+    "lower_xtensor",
+    lower_xtensor_db,
+    "fast_run",
+    "fast_compile",
+    "minimum_compile",
+    position=0.1,
+)
+
+
+def register_lower_xtensor(
+    node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
+):
+    if isinstance(node_rewriter, str):
+
+        def register(inner_rewriter: RewriteDatabase | NodeRewriter):
+            return register_lower_xtensor(
+                inner_rewriter, node_rewriter, *tags, **kwargs
+            )
+
+        return register
+
+    else:
+        name = kwargs.pop("name", None) or node_rewriter.__name__  # type: ignore
+        lower_xtensor_db.register(
+            name,
+            node_rewriter,
+            "fast_run",
+            "fast_compile",
+            "minimum_compile",
+            *tags,
+            **kwargs,
+        )
+        return node_rewriter
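For illustration, a sketch (with a made-up rewrite) of how `register_lower_xtensor` is used as a decorator, mirroring the rewrites in pytensor/xtensor/rewriting/basic.py above:

```python
from pytensor.graph import node_rewriter
from pytensor.xtensor.basic import Rename
from pytensor.xtensor.rewriting.utils import register_lower_xtensor


@register_lower_xtensor
@node_rewriter(tracks=[Rename])
def useless_identity_rename(fgraph, node):
    """Hypothetical example: drop a Rename whose new dims equal the old ones."""
    [x] = node.inputs
    if node.op.new_dims == x.type.dims:
        return [x]
```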
