Skip to content

Commit 247fd27

Browse files
committed
Basic labeled tensor functionality
1 parent 87ecb5f commit 247fd27

File tree

10 files changed

+733
-0
lines changed

10 files changed

+733
-0
lines changed

.github/workflows/test.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ jobs:
8282
install-numba: [0]
8383
install-jax: [0]
8484
install-torch: [0]
85+
install-xarray: [0]
8586
part:
8687
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
8788
- "tests/scan"
@@ -115,6 +116,7 @@ jobs:
115116
install-numba: 0
116117
install-jax: 0
117118
install-torch: 0
119+
install-xarray: 0
118120
- install-numba: 1
119121
os: "ubuntu-latest"
120122
python-version: "3.10"
@@ -150,6 +152,13 @@ jobs:
150152
fast-compile: 0
151153
float32: 0
152154
part: "tests/link/pytorch"
155+
- install-xarray: 1
156+
os: "ubuntu-latest"
157+
python-version: "3.13"
158+
numpy-version: ">=2.0"
159+
fast-compile: 0
160+
float32: 0
161+
part: "tests/xtensor"
153162
- os: macos-15
154163
python-version: "3.13"
155164
numpy-version: ">=2.0"
@@ -196,6 +205,7 @@ jobs:
196205
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
197206
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
198207
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
208+
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
199209
pip install pytest-sphinx
200210
201211
pip install -e ./
@@ -212,6 +222,7 @@ jobs:
212222
INSTALL_NUMBA: ${{ matrix.install-numba }}
213223
INSTALL_JAX: ${{ matrix.install-jax }}
214224
INSTALL_TORCH: ${{ matrix.install-torch}}
225+
INSTALL_XARRAY: ${{ matrix.install-xarray }}
215226
OS: ${{ matrix.os}}
216227

217228
- name: Run tests

pytensor/xtensor/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import warnings
2+
3+
import pytensor.xtensor.rewriting
4+
from pytensor.xtensor.type import (
5+
XTensorType,
6+
as_xtensor,
7+
xtensor,
8+
xtensor_constant,
9+
)
10+
11+
12+
# Emitted as a side effect of importing this package; no explicit warning
# category or stacklevel is passed.
warnings.warn("xtensor module is experimental and full of bugs")

pytensor/xtensor/basic.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from collections.abc import Sequence
2+
3+
from pytensor.compile import ViewOp
4+
from pytensor.graph import Apply, Op
5+
from pytensor.link.c.op import COp
6+
from pytensor.link.jax.linker import jax_funcify
7+
from pytensor.link.numba.linker import numba_funcify
8+
from pytensor.link.pytorch.linker import pytorch_funcify
9+
from pytensor.tensor.type import TensorType
10+
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
11+
12+
13+
class XOp(Op):
    """Base class for xtensor Ops that are never evaluated directly.

    Calling `perform` always fails: the node is expected to be replaced by
    equivalent tensor operations before the graph is executed.
    """

    def perform(self, node, inputs, outputs):
        message = (
            f"xtensor operation {self} must be lowered to equivalent tensor operations"
        )
        raise NotImplementedError(message)
20+
21+
22+
class XTypeCastOp(COp):
    """Common parent of the Ops that cast between TensorType and XTensorType.

    It behaves like a `ViewOp`, except that the input and the output are not
    expected to have identical types.
    """

    # Output 0 is a view of input 0.
    view_map = {0: [0]}

    def perform(self, node, inputs, output_storage):
        # Pass the input object through unchanged.
        out_cell = output_storage[0]
        out_cell[0] = inputs[0]

    def c_code(self, node, nodename, inp, out, sub):
        # The template is filled in from `locals()`, so the local names
        # `iname`, `oname` and `fail` must keep these exact spellings.
        [iname] = inp
        [oname] = out
        fail = sub["fail"]

        code = ViewOp.c_code_and_version[TensorType][0]
        return code % locals()

    def c_code_cache_version(self):
        # Reuse the version registered together with the ViewOp C code.
        version = ViewOp.c_code_and_version[TensorType][1]
        return (version,)
44+
45+
46+
@numba_funcify.register(XTypeCastOp)
def numba_funcify_XCast(op, *args, **kwargs):
    """Numba implementation of `XTypeCastOp`: a jitted function returning its argument."""
    # Imported inside the function, as in the original code.
    from pytensor.link.numba.dispatch.basic import numba_njit

    def xcast(x):
        return x

    return numba_njit(xcast)
55+
56+
57+
@jax_funcify.register(XTypeCastOp)
@pytorch_funcify.register(XTypeCastOp)
def funcify_XCast(op, *args, **kwargs):
    """Implementation of `XTypeCastOp` registered for both JAX and PyTorch.

    The returned function hands its single argument back untouched.
    """

    def xcast(x):
        # Nothing to convert: the cast only changes the symbolic type.
        return x

    return xcast
64+
65+
66+
class TensorFromXTensor(XTypeCastOp):
    """Cast an xtensor variable to a tensor variable of the same dtype and shape."""

    __props__ = ()

    def make_node(self, x):
        x_type = x.type
        if isinstance(x_type, XTensorType):
            tensor_type = TensorType(x_type.dtype, shape=x_type.shape)
            return Apply(self, [x], [tensor_type()])
        raise TypeError(f"x must be have an XTensorType, got {type(x_type)}")


# Shared instance of the Op (it has no props to parametrize).
tensor_from_xtensor = TensorFromXTensor()
77+
78+
79+
class XTensorFromTensor(XTypeCastOp):
    """Cast a tensor variable to an xtensor variable labelled with `dims`.

    The output keeps the dtype and shape of the input.
    """

    __props__ = ("dims",)

    def __init__(self, dims: Sequence[str]):
        super().__init__()
        # Any sequence of names is accepted; it is stored as a tuple.
        self.dims = tuple(dims)

    def make_node(self, x):
        x_type = x.type
        if isinstance(x_type, TensorType):
            labelled = xtensor(dtype=x_type.dtype, dims=self.dims, shape=x_type.shape)
            return Apply(self, [x], [labelled])
        raise TypeError(f"x must be an TensorType type, got {type(x_type)}")
91+
92+
93+
def xtensor_from_tensor(x, dims):
    """Apply an `XTensorFromTensor` Op built with `dims` to `x` and return the result."""
    cast_op = XTensorFromTensor(dims=dims)
    return cast_op(x)
95+
96+
97+
class Rename(XTypeCastOp):
    """Relabel the dimensions of an xtensor variable.

    The output has the type of the input with its `dims` replaced by
    `new_dims`.

    Parameters
    ----------
    new_dims
        The complete sequence of dimension names of the output.
    """

    __props__ = ("new_dims",)

    def __init__(self, new_dims: tuple[str, ...]):
        super().__init__()
        # Coerce to a tuple, as XTensorFromTensor does with its `dims`, so a
        # list passed by a caller does not end up as a mutable value listed
        # in `__props__`.
        self.new_dims = tuple(new_dims)

    def make_node(self, x):
        # Accept anything `as_xtensor` can convert.
        x = as_xtensor(x)
        output = x.type.clone(dims=self.new_dims)()
        return Apply(self, [x], [output])
108+
109+
110+
def rename(x, name_dict: dict[str, str] | None = None, **names: str):
    """Return `x` with some of its dimensions renamed.

    Parameters
    ----------
    x
        Variable whose dimensions are renamed; converted with `as_xtensor`.
    name_dict
        Optional mapping from existing dimension names to their new names.
    **names
        The same mapping given as keyword arguments. Cannot be combined with
        `name_dict`.

    Raises
    ------
    ValueError
        If both `name_dict` and keyword names are given, or if a name to be
        replaced is not one of the dimensions of `x`.
    """
    if name_dict is not None:
        if names:
            raise ValueError("Cannot use both positional and keyword names in rename")
        names = name_dict

    x = as_xtensor(x)
    old_names = x.type.dims
    new_names = list(old_names)
    for old_name, new_name in names.items():
        # A failed `.index` lookup raises ValueError (not IndexError), so that
        # is the exception to translate into the descriptive message below.
        try:
            position = old_names.index(old_name)
        except ValueError:
            raise ValueError(
                f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
            ) from None
        new_names[position] = new_name

    return Rename(tuple(new_names))(x)

pytensor/xtensor/readme.md

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# XTensor Module
2+
3+
This module implements an abstraction layer on top of regular tensor operations that behaves like xarray.
4+
5+
A new type `XTensorType`, generalizes the `TensorType` with the addition of a `dims` attribute,
6+
that labels the dimensions of the tensor.
7+
8+
Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects.
9+
10+
The module implements several PyTensor operations (`XOp`s), whose signatures mimic those of xarray (and xarray_einstats) DataArray operations.
11+
These operations, unlike most regular PyTensor operations, cannot be directly evaluated, but require a rewrite (lowering) into
12+
a regular tensor graph that can itself be evaluated as usual.
13+
14+
Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray.
15+
If the existing XOps can be composed to produce the desired result, then we can use them directly.
16+
17+
## Coordinates
18+
For now, there is no analogue of xarray coordinates, so you won't be able to do coordinate operations like `.sel`.
19+
The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor.
20+
Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor.
21+
22+
## Example
23+
24+
```python
25+
import pytensor.tensor as pt
26+
import pytensor.xtensor as px
27+
28+
a = pt.tensor("a", shape=(3,))
29+
b = pt.tensor("b", shape=(4,))
30+
31+
ax = px.as_xtensor(a, dims=["x"])
32+
bx = px.as_xtensor(b, dims=["y"])
33+
34+
zx = ax + bx
35+
assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))
36+
37+
z = zx.values
38+
z.dprint()
39+
# TensorFromXTensor [id A]
40+
# └─ XElemwise{scalar_op=Add()} [id B]
41+
# ├─ XTensorFromTensor{dims=('x',)} [id C]
42+
# │ └─ a [id D]
43+
# └─ XTensorFromTensor{dims=('y',)} [id E]
44+
# └─ b [id F]
45+
```
46+
47+
Once we compile the graph, no `XOp`s are left.
48+
49+
```python
50+
import pytensor
51+
52+
with pytensor.config.change_flags(optimizer_verbose=True):
53+
fn = pytensor.function([a, b], z)
54+
55+
# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
56+
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
57+
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
58+
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)
59+
60+
fn.dprint()
61+
# Add [id A] 2
62+
# ├─ ExpandDims{axis=1} [id B] 1
63+
# │ └─ a [id C]
64+
# └─ ExpandDims{axis=0} [id D] 0
65+
# └─ b [id E]
66+
```
67+
68+
69+
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import pytensor.xtensor.rewriting.basic
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from pytensor.graph import node_rewriter
2+
from pytensor.tensor.basic import register_infer_shape
3+
from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
4+
from pytensor.xtensor.basic import (
5+
Rename,
6+
TensorFromXTensor,
7+
XTensorFromTensor,
8+
xtensor_from_tensor,
9+
)
10+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
11+
12+
13+
@register_infer_shape
@register_useless
@register_canonicalize
@register_xcanonicalize
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor(fgraph, node):
    """TensorFromXTensor(XTensorFromTensor(x)) -> x

    Returns the replacement list, or None when the pattern does not match.
    """
    [xtensor_input] = node.inputs
    producer = xtensor_input.owner
    if not producer or not isinstance(producer.op, XTensorFromTensor):
        return None
    # The round trip tensor -> xtensor -> tensor is replaced by the tensor.
    return [producer.inputs[0]]
23+
24+
25+
@register_infer_shape
@register_useless
@register_canonicalize
@register_xcanonicalize
@node_rewriter(tracks=[XTensorFromTensor])
def useless_xtensor_from_tensor(fgraph, node):
    """XTensorFromTensor(TensorFromXTensor(x)) -> x

    The replacement is only made when `x` already has the dims that the outer
    cast assigns; otherwise `x` would stand in for a variable labelled with
    different dims.

    Returns the replacement list, or None when the rewrite does not apply.
    """
    [tensor_input] = node.inputs
    producer = tensor_input.owner
    if not (producer and isinstance(producer.op, TensorFromXTensor)):
        return None

    original = producer.inputs[0]
    # The cast may relabel the dimensions (e.g. dims ("a",) -> ("b",)); in that
    # case the round trip is not a no-op and must be kept.
    if tuple(original.type.dims) != tuple(node.op.dims):
        return None
    return [original]
35+
36+
37+
@register_xcanonicalize
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor_of_rename(fgraph, node):
    """TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)

    Returns the replacement list, or None when the input is not produced by a
    `Rename`.
    """
    [renamed_x] = node.inputs
    producer = renamed_x.owner
    if producer and isinstance(producer.op, Rename):
        [x] = producer.inputs
        # Re-apply the same cast directly to the un-renamed variable.
        return node.op(x, return_list=True)
    return None
45+
46+
47+
@register_xcanonicalize
@node_rewriter(tracks=[Rename])
def useless_rename(fgraph, node):
    """Fuse a `Rename` with the Op that produced its input.

    Rename(Rename(x, inner_dims), outer_dims) -> Rename(x, outer_dims)
    Rename(XTensorFromTensor(x, inner_dims), outer_dims) -> XTensorFromTensor(x, outer_dims)

    Returns the replacement list, or None when neither pattern matches.
    """
    [renamed_x] = node.inputs
    producer = renamed_x.owner
    if not producer:
        return None

    if isinstance(producer.op, Rename):
        # Only the outermost names survive, so the inner Rename is skipped.
        [x] = producer.inputs
        return [node.op(x)]

    # The input of a Rename is an xtensor, so the cast that can feed it is
    # XTensorFromTensor; its own input is the plain tensor that
    # `xtensor_from_tensor` requires.
    if isinstance(producer.op, XTensorFromTensor):
        [x] = producer.inputs
        return [xtensor_from_tensor(x, dims=node.op.new_dims)]

    return None
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from pytensor.compile import optdb
2+
from pytensor.graph.rewriting.basic import NodeRewriter
3+
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
4+
5+
6+
# Create the rewrite database that holds the xtensor rewrites. It is registered
# in `optdb` under the name "xcanonicalize" at position 0, with the tags
# "fast_run", "fast_compile" and "xtensor". `register_xcanonicalize` retrieves
# it through `optdb["xtensor"]`.
optdb.register(
    "xcanonicalize",
    EquilibriumDB(ignore_newtrees=False),
    "fast_run",
    "fast_compile",
    "xtensor",
    position=0,
)
14+
15+
16+
def register_xcanonicalize(
    node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
):
    """Register a rewriter in the database found at `optdb["xtensor"]`.

    Two call forms are supported:

    * With a rewriter as first argument, it is registered under
      `kwargs["name"]` (or its `__name__` when no name is given) with the tags
      "fast_run", "fast_compile" and any extra `tags`; the rewriter itself is
      returned, so the function can be used as a plain decorator.
    * With a string as first argument, a decorator is returned that performs
      the registration above, passing the string as an additional tag.
    """
    if not isinstance(node_rewriter, str):
        name = kwargs.pop("name", None) or node_rewriter.__name__  # type: ignore
        optdb["xtensor"].register(
            name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
        )
        return node_rewriter

    # Decorator-factory form: the leading string becomes the first extra tag.
    def register(inner_rewriter: RewriteDatabase | NodeRewriter):
        return register_xcanonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)

    return register

0 commit comments

Comments
 (0)