Skip to content

Commit 2054ad8

Browse files
committed
WIP Basic labeled tensor functionality
TODO: Split Stack from commit
1 parent 2610862 commit 2054ad8

File tree

12 files changed

+837
-0
lines changed

12 files changed

+837
-0
lines changed

pytensor/xtensor/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import warnings
2+
3+
import pytensor.xtensor.rewriting
4+
from pytensor.xtensor.type import (
5+
XTensorType,
6+
as_xtensor,
7+
xtensor,
8+
xtensor_constant,
9+
)
10+
11+
12+
# Warn at import time that this module is experimental.
warnings.warn("xtensor module is experimental and full of bugs")

pytensor/xtensor/basic.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from collections.abc import Sequence
2+
3+
from pytensor.graph import Apply, Op
4+
from pytensor.tensor.type import TensorType
5+
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
6+
7+
8+
class XOp(Op):
    """Base class for xtensor Ops that exist only symbolically.

    ``perform`` unconditionally raises, so a graph that still contains such a
    node cannot be evaluated until the node has been replaced by regular
    tensor operations.
    """

    def perform(self, node, inputs, outputs):
        """Always raise ``NotImplementedError``; the node must be lowered first."""
        message = (
            f"xtensor operation {self} must be lowered to equivalent tensor operations"
        )
        raise NotImplementedError(message)
15+
16+
17+
class XViewOp(Op):
    """Op that passes its first input through, uncopied, as its first output."""

    # Make this a View Op with C-implementation
    # Output 0 is declared as a view of input 0.
    view_map = {0: [0]}

    def perform(self, node, inputs, output_storage):
        """Place the first input object directly into the first output cell."""
        source = inputs[0]
        out_cell = output_storage[0]
        out_cell[0] = source
23+
24+
25+
class TensorFromXTensor(XViewOp):
    """Convert an xtensor variable into a regular tensor variable.

    The output has the same dtype and static shape as the input; only the
    ``dims`` labels are dropped.
    """

    __props__ = ()

    def make_node(self, x) -> Apply:
        """Build the node that views ``x`` as a ``TensorType`` variable.

        Raises
        ------
        TypeError
            If ``x.type`` is not an ``XTensorType``.
        """
        if not isinstance(x.type, XTensorType):
            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
        output = TensorType(x.type.dtype, shape=x.type.shape)()
        return Apply(self, [x], [output])
33+
34+
35+
# Module-level instance of the props-less TensorFromXTensor Op; call it like a function.
tensor_from_xtensor = TensorFromXTensor()
36+
37+
38+
class XTensorFromTensor(XViewOp):
    """Attach dimension labels to a regular tensor variable.

    Parameters
    ----------
    dims
        Names for the dimensions of the output; stored as a tuple so the Op
        stays hashable through ``__props__``.
    """

    __props__ = ("dims",)

    def __init__(self, dims: Sequence[str]):
        super().__init__()
        self.dims = tuple(dims)

    def make_node(self, x) -> Apply:
        """Build the node that views ``x`` as an xtensor with ``self.dims``.

        Raises
        ------
        TypeError
            If ``x.type`` is not a ``TensorType``.
        """
        if not isinstance(x.type, TensorType):
            raise TypeError(f"x must have a TensorType, got {type(x.type)}")
        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
        return Apply(self, [x], [output])
50+
51+
52+
def xtensor_from_tensor(x, dims):
    """Wrap tensor ``x`` as an xtensor labelled with ``dims``.

    A new ``XTensorFromTensor`` Op is created for the given ``dims`` and
    applied to ``x``.
    """
    labelling_op = XTensorFromTensor(dims=dims)
    return labelling_op(x)
54+
55+
56+
class Rename(XViewOp):
    """Relabel all dimensions of an xtensor without touching its data.

    Parameters
    ----------
    new_dims
        The complete set of dimension names for the output, in order.
    """

    __props__ = ("new_dims",)

    def __init__(self, new_dims: Sequence[str]):
        super().__init__()
        # Coerce to a tuple (as XTensorFromTensor does for ``dims``) so that the
        # Op remains hashable via ``__props__`` even when a list is passed.
        self.new_dims = tuple(new_dims)

    def make_node(self, x):
        """Build the node whose output type is ``x``'s type with ``new_dims``."""
        x = as_xtensor(x)
        output = x.type.clone(dims=self.new_dims)()
        return Apply(self, [x], [output])
67+
68+
69+
def rename(x, name_dict: dict[str, str] | None = None, **names: str):
    """Return ``x`` with some of its dimensions renamed.

    Parameters
    ----------
    x
        Variable to rename; converted with ``as_xtensor``.
    name_dict
        Mapping from existing dimension names to new names.
    **names
        The same mapping given as keyword arguments. Mutually exclusive with
        ``name_dict``.

    Raises
    ------
    ValueError
        If both ``name_dict`` and keyword names are given, or if a name to be
        replaced is not one of the dimensions of ``x``.
    """
    if name_dict is not None:
        if names:
            raise ValueError("Cannot use both positional and keyword names in rename")
        names = name_dict

    x = as_xtensor(x)
    old_names = x.type.dims
    new_names = list(old_names)
    for old_name, new_name in names.items():
        try:
            # Sequence ``index`` signals a missing item with ValueError
            # (not IndexError), so that is the exception to translate.
            position = old_names.index(old_name)
        except ValueError:
            raise ValueError(
                f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
            ) from None
        new_names[position] = new_name

    return Rename(tuple(new_names))(x)

pytensor/xtensor/readme.md

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# XTensor Module
2+
3+
This module implements an abstraction layer over regular tensor operations that behaves like Xarray.
4+
5+
A new type, `XTensorType`, generalizes `TensorType` with the addition of a `dims` attribute
6+
that labels the dimensions of the tensor.
7+
8+
Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects.
9+
10+
The module implements several PyTensor operations (`XOp`s) whose signatures mimic those of xarray (and xarray-einstats) DataArray operations.
11+
These operations, unlike most regular PyTensor operations, cannot be directly evaluated, but require a rewrite (lowering) into
12+
a regular tensor graph that can itself be evaluated as usual.
13+
14+
Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray.
15+
If the existing XOps can be composed to produce the desired result, then we can use them directly.
16+
17+
## Coordinates
18+
For now, there is no analogue of xarray coordinates, so you won't be able to do coordinate operations like `.sel`.
19+
The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor.
20+
Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor.
21+
22+
## Example
23+
24+
```python
25+
import pytensor.tensor as pt
26+
import pytensor.xtensor as px
27+
28+
a = pt.tensor("a", shape=(3,))
29+
b = pt.tensor("b", shape=(4,))
30+
31+
ax = px.as_xtensor(a, dims=["x"])
32+
bx = px.as_xtensor(b, dims=["y"])
33+
34+
zx = ax + bx
35+
assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))
36+
37+
z = zx.values
38+
z.dprint()
39+
# TensorFromXTensor [id A]
40+
# └─ XElemwise{scalar_op=Add()} [id B]
41+
# ├─ XTensorFromTensor{dims=('x',)} [id C]
42+
# │ └─ a [id D]
43+
# └─ XTensorFromTensor{dims=('y',)} [id E]
44+
# └─ b [id F]
45+
```
46+
47+
Once we compile the graph, no `XOp`s are left.
48+
49+
```python
50+
import pytensor
51+
52+
with pytensor.config.change_flags(optimizer_verbose=True):
53+
fn = pytensor.function([a, b], z)
54+
55+
# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
56+
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
57+
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
58+
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)
59+
60+
fn.dprint()
61+
# Add [id A] 2
62+
# ├─ ExpandDims{axis=1} [id B] 1
63+
# │ └─ a [id C]
64+
# └─ ExpandDims{axis=0} [id D] 0
65+
# └─ b [id E]
66+
```
67+
68+
69+
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
import pytensor.xtensor.rewriting.basic
2+
import pytensor.xtensor.rewriting.shape
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from pytensor.graph import node_rewriter
2+
from pytensor.xtensor.basic import (
3+
Rename,
4+
TensorFromXTensor,
5+
XTensorFromTensor,
6+
xtensor_from_tensor,
7+
)
8+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
9+
10+
11+
@register_xcanonicalize
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor(fgraph, node):
    """TensorFromXTensor(XTensorFromTensor(x)) -> x

    Returns ``None`` (no replacement) when the input was not produced by an
    ``XTensorFromTensor`` node.
    """
    (inner,) = node.inputs
    producer = inner.owner
    if not producer or not isinstance(producer.op, XTensorFromTensor):
        return None
    # The two conversions cancel out: hand back the original tensor.
    return [producer.inputs[0]]
18+
19+
20+
@register_xcanonicalize
@node_rewriter(tracks=[XTensorFromTensor])
def useless_xtensor_from_tensor(fgraph, node):
    """XTensorFromTensor(TensorFromXTensor(x)) -> x

    Returns ``None`` (no replacement) when the input was not produced by a
    ``TensorFromXTensor`` node.
    """
    (inner,) = node.inputs
    producer = inner.owner
    if not producer or not isinstance(producer.op, TensorFromXTensor):
        return None
    # Replace the round trip with the variable that entered it.
    return [producer.inputs[0]]
27+
28+
29+
@register_xcanonicalize
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor_of_rename(fgraph, node):
    """TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)

    Returns ``None`` (no replacement) when the input was not produced by a
    ``Rename`` node.
    """
    (maybe_renamed,) = node.inputs
    producer = maybe_renamed.owner
    if not producer or not isinstance(producer.op, Rename):
        return None
    # Skip the Rename and convert its input directly.
    (unrenamed,) = producer.inputs
    return node.op(unrenamed, return_list=True)
37+
38+
39+
@register_xcanonicalize
@node_rewriter(tracks=[Rename])
def useless_rename(fgraph, node):
    """Fuse a Rename with the node that produced its input.

    Rename(Rename(x, inner_dims), outer_dims) -> Rename(x, outer_dims)
    Rename(XTensorFromTensor(x, inner_dims), outer_dims) -> XTensorFromTensor(x, outer_dims)

    Returns ``None`` (no replacement) in every other case.
    """
    [renamed_x] = node.inputs
    owner = renamed_x.owner
    if not owner:
        return None

    if isinstance(owner.op, Rename):
        # Only the outermost set of names survives.
        [x] = owner.inputs
        return [node.op(x)]

    # The input of a Rename is an xtensor, so the producer to fuse with is
    # XTensorFromTensor (whose input is the raw tensor), not TensorFromXTensor.
    if isinstance(owner.op, XTensorFromTensor):
        [x] = owner.inputs
        return [xtensor_from_tensor(x, dims=node.op.new_dims)]

    return None
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from pytensor.graph import node_rewriter
2+
from pytensor.tensor import moveaxis
3+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
4+
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
5+
from pytensor.xtensor.shape import Stack
6+
7+
8+
@register_xcanonicalize
@node_rewriter(tracks=[Stack])
def lower_stack(fgraph, node):
    """Lower a ``Stack`` node to moveaxis (+ reshape) on the underlying tensor.

    The axes of the stacked dims are moved to the end of the tensor and, when
    more than one dim is stacked, collapsed into a single trailing axis. The
    result is wrapped back into an xtensor carrying the dims of the original
    output.
    """
    [x] = node.inputs
    stacked_dims = node.op.stacked_dims
    n_batch = x.type.ndim - len(stacked_dims)

    # NOTE(review): axes are collected in the order they appear in x.type.dims,
    # not in the order listed in stacked_dims - confirm this is the intended
    # stacking order.
    source_axes = [axis for axis, dim in enumerate(x.type.dims) if dim in stacked_dims]
    target_axes = tuple(range(-len(source_axes), 0))

    transposed = moveaxis(
        tensor_from_xtensor(x), source=source_axes, destination=target_axes
    )

    if n_batch == (x.type.ndim - 1):
        # This happens when we stack a "single" dimension, in this case all we need is the transpose
        # Note: If we have meaningful rewrites before lowering, consider canonicalizing this as a Transpose + Rename
        lowered = transposed
    else:
        # Keep the leading batch axes and merge everything else into one axis.
        merged_shape = (*tuple(transposed.shape)[:n_batch], -1)
        lowered = transposed.reshape(merged_shape)

    return [xtensor_from_tensor(lowered, dims=node.outputs[0].type.dims)]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from pytensor.compile import optdb
2+
from pytensor.graph.rewriting.basic import NodeRewriter
3+
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
4+
5+
6+
# Create the database that holds the xtensor rewrites and register it in the
# global `optdb` under the name "xcanonicalize". The extra "xtensor" tag is the
# key `register_xcanonicalize` uses to look this database up again.
optdb.register(
    "xcanonicalize",
    EquilibriumDB(ignore_newtrees=False),
    "fast_run", "fast_compile", "xtensor",
    position=0,
)
14+
15+
16+
def register_xcanonicalize(
    node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
):
    """Register a rewrite in the database reachable as ``optdb["xtensor"]``.

    Can be used directly as a decorator (``@register_xcanonicalize``) or as a
    decorator factory whose first argument is a string tag
    (``@register_xcanonicalize("tag", ...)``).

    Parameters
    ----------
    node_rewriter
        The rewrite to register, or a string tag when used as a factory.
    *tags
        Additional tags, added after ``"fast_run"`` and ``"fast_compile"``.
    **kwargs
        Forwarded to the database's ``register``; ``name`` overrides the
        registration name, which otherwise is ``node_rewriter.__name__``.

    Returns
    -------
    The rewrite itself, or a decorator when called with a string.
    """
    if not isinstance(node_rewriter, str):
        rewrite_name = kwargs.pop("name", None) or node_rewriter.__name__
        xtensor_db = optdb["xtensor"]
        xtensor_db.register(
            rewrite_name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
        )
        return node_rewriter

    # Called with a string: treat it as the first tag and defer registration
    # until the returned decorator receives the actual rewrite.
    leading_tag = node_rewriter

    def register(inner_rewriter: RewriteDatabase | NodeRewriter):
        return register_xcanonicalize(inner_rewriter, leading_tag, *tags, **kwargs)

    return register

pytensor/xtensor/shape.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from collections.abc import Sequence
2+
3+
from pytensor.graph import Apply
4+
from pytensor.xtensor.basic import XOp
5+
from pytensor.xtensor.type import as_xtensor, xtensor
6+
7+
8+
class Stack(XOp):
    """Symbolic Op that merges several dims of an xtensor into one new dim.

    The output keeps the non-stacked dims in their original order, followed by
    a single trailing dim named ``new_dim_name``.

    Parameters
    ----------
    new_dim_name
        Name of the dimension created by stacking.
    stacked_dims
        Names of the existing dimensions to merge. Must be non-empty, free of
        duplicates and must not contain ``new_dim_name``.
    """

    __props__ = ("new_dim_name", "stacked_dims")

    def __init__(self, new_dim_name: str, stacked_dims: tuple[str, ...]):
        super().__init__()
        # Props are used for hashing/equality, so store an immutable tuple.
        stacked_dims = tuple(stacked_dims)
        if new_dim_name in stacked_dims:
            raise ValueError(
                f"Stacking dim {new_dim_name} must not be in {stacked_dims}"
            )
        if not stacked_dims:
            raise ValueError(f"Stacking dims must not be empty: got {stacked_dims}")
        # Duplicates would make len(stacked_dims) disagree with the number of
        # axes actually stacked, yielding a wrong output type.
        if len(set(stacked_dims)) != len(stacked_dims):
            raise ValueError(f"Stacking dims must be unique: got {stacked_dims}")
        self.new_dim_name = new_dim_name
        self.stacked_dims = stacked_dims

    def make_node(self, x):
        """Build the node and infer the output dims and static shape.

        Raises
        ------
        ValueError
            If some stacked dim is missing from ``x`` or if ``new_dim_name``
            is already a dim of ``x``.
        """
        x = as_xtensor(x)
        if not (set(self.stacked_dims) <= set(x.type.dims)):
            raise ValueError(
                f"Stacking dims {self.stacked_dims} must be a subset of {x.type.dims}"
            )
        if self.new_dim_name in x.type.dims:
            raise ValueError(
                f"Stacking dim {self.new_dim_name} must not be in {x.type.dims}"
            )

        batch_dims = []
        batch_shape = []
        # Static length of the stacked dim: the product of the stacked lengths,
        # or None as soon as any of them is unknown.
        stack_shape = 1
        for dim, length in zip(x.type.dims, x.type.shape):
            if dim not in self.stacked_dims:
                batch_dims.append(dim)
                batch_shape.append(length)
            elif stack_shape is not None:
                stack_shape = None if length is None else stack_shape * length

        output = xtensor(
            dtype=x.type.dtype,
            shape=(*batch_shape, stack_shape),
            dims=(*batch_dims, self.new_dim_name),
        )
        return Apply(self, [x], [output])
56+
57+
58+
def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):
    """Stack groups of dimensions of ``x`` into new dimensions.

    Parameters
    ----------
    x
        Variable whose dimensions are stacked.
    dim
        Mapping from each new dimension name to the sequence of existing
        dimension names it replaces.
    **dims
        The same mapping given as keyword arguments. Mutually exclusive with
        ``dim``.

    Returns
    -------
    The result of applying one ``Stack`` Op per mapping entry, in iteration
    order; ``x`` itself when the mapping is empty.

    Raises
    ------
    ValueError
        If both ``dim`` and keyword ``dims`` are given.
    TypeError
        If the dims to stack are given as a single string.
    """
    if dim is None:
        requested = dims
    elif dims:
        raise ValueError("Cannot use both positional dim and keyword dims in stack")
    else:
        requested = dim

    result = x
    for new_dim_name, stacked_dims in requested.items():
        # A bare string would be iterated character by character below.
        if isinstance(stacked_dims, str):
            raise TypeError(
                f"Stacking dims must be a sequence of strings, got a single string: {stacked_dims}"
            )
        stack_op = Stack(new_dim_name, tuple(stacked_dims))
        result = stack_op(result)
    return result

0 commit comments

Comments
 (0)