Skip to content

Commit 8dfd588

Browse files
committed
experiment with first class dim objects
1 parent 2ab60b3 commit 8dfd588

File tree

8 files changed

+485
-83
lines changed

8 files changed

+485
-83
lines changed

pytensor/xtensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pytensor.xtensor.shape import concat
77
from pytensor.xtensor.type import (
88
as_xtensor,
9+
dim,
910
xtensor,
1011
xtensor_constant,
1112
)

pytensor/xtensor/basic.py

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1-
from collections.abc import Sequence
2-
31
from pytensor.compile.ops import TypeCastingOp
42
from pytensor.graph import Apply, Op
3+
from pytensor.scalar.basic import uint64
4+
from pytensor.tensor.basic import ones as tensor_ones
5+
from pytensor.tensor.basic import zeros as tensor_zeros
6+
from pytensor.tensor.shape import specify_shape
57
from pytensor.tensor.type import TensorType
6-
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
8+
from pytensor.xtensor.type import DimVariable, XTensorType, as_xtensor, xtensor
9+
10+
11+
DIM_LENGTH_SCALAR = uint64
712

813

914
class XOp(Op):
@@ -32,6 +37,7 @@ def make_node(self, x):
3237
return Apply(self, [x], [output])
3338

3439
def L_op(self, inputs, outs, g_outs):
40+
# TODO fix
3541
[x] = inputs
3642
[g_out] = g_outs
3743
return [xtensor_from_tensor(g_out, dims=x.type.dims)]
@@ -41,46 +47,49 @@ def L_op(self, inputs, outs, g_outs):
4147

4248

4349
class XTensorFromTensor(XTypeCastOp):
44-
__props__ = ("dims",)
45-
46-
def __init__(self, dims: Sequence[str]):
47-
super().__init__()
48-
self.dims = tuple(dims)
50+
__props__ = ()
4951

50-
def make_node(self, x):
52+
def make_node(self, x, *dims):
5153
if not isinstance(x.type, TensorType):
5254
raise TypeError(f"x must be an TensorType type, got {type(x.type)}")
53-
output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
54-
return Apply(self, [x], [output])
55+
output = xtensor(dtype=x.type.dtype, dims=dims)
56+
return Apply(self, [x, *dims], [output])
5557

5658
def L_op(self, inputs, outs, g_outs):
59+
# TODO fix
5760
[g_out] = g_outs
5861
return [tensor_from_xtensor(g_out)]
5962

6063

61-
def xtensor_from_tensor(x, dims, name=None):
62-
return XTensorFromTensor(dims=dims)(x, name=name)
64+
def xtensor_from_tensor(x, dims, name=None, check: bool = True):
65+
if check:
66+
x = specify_shape(x, [dim.size for dim in dims])
67+
return XTensorFromTensor()(x, *dims, name=name)
6368

6469

65-
class Rename(XTypeCastOp):
66-
__props__ = ("new_dims",)
70+
class MapDims(XTypeCastOp):
71+
__props__ = ("new_dim_indices",)
6772

68-
def __init__(self, new_dims: tuple[str, ...]):
69-
super().__init__()
70-
self.new_dims = new_dims
73+
def __init__(self, new_dim_indices: tuple[int, ...]):
74+
self.new_dims_indices = new_dim_indices
7175

72-
def make_node(self, x):
76+
def make_node(self, x, *new_dims):
7377
x = as_xtensor(x)
74-
output = x.type.clone(dims=self.new_dims)()
78+
new_dims = list(x.dims)
79+
for i, idx in enumerate(self.new_dims_indices):
80+
new_dims[idx] = new_dims[i]
81+
82+
output = x.type.clone(dims=new_dims)()
7583
return Apply(self, [x], [output])
7684

7785
def L_op(self, inputs, outs, g_outs):
86+
# TODO fix
7887
[x] = inputs
7988
[g_out] = g_outs
80-
return [rename(g_out, dims=x.type.dims)]
89+
return [map_dims(g_out, dims=x.type.dims)]
8190

8291

83-
def rename(x, name_dict: dict[str, str] | None = None, **names: str):
92+
def map_dims(x, name_dict: dict[DimVariable, DimVariable] | None = None, **names):
8493
if name_dict is not None:
8594
if names:
8695
raise ValueError("Cannot use both positional and keyword names in rename")
@@ -97,4 +106,30 @@ def rename(x, name_dict: dict[str, str] | None = None, **names: str):
97106
f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
98107
)
99108

100-
return Rename(tuple(new_names))(x)
109+
return MapDims(tuple(new_names))(x)
110+
111+
112+
def zeros(*dims, dtype=None, name=None):
113+
"""Create a new XTensor filled with zeros."""
114+
if not dims:
115+
raise ValueError("At least one dimension must be specified")
116+
117+
return xtensor_from_tensor(
118+
tensor_zeros(shape=[dim.size for dim in dims], dtype=dtype),
119+
dims=dims,
120+
name=name,
121+
check=False,
122+
)
123+
124+
125+
def ones(*dims, dtype=None, name=None):
126+
"""Create a new XTensor filled with zeros."""
127+
if not dims:
128+
raise ValueError("At least one dimension must be specified")
129+
130+
return xtensor_from_tensor(
131+
tensor_ones(shape=[dim.size for dim in dims], dtype=dtype),
132+
dims=dims,
133+
name=name,
134+
check=False,
135+
)

pytensor/xtensor/dims.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
from __future__ import annotations
2+
3+
from uuid import uuid4
4+
5+
import numpy as np
6+
7+
from pytensor.graph.basic import Apply
8+
from pytensor.graph.op import Op, Variable
9+
from pytensor.scalar.basic import ScalarVariable
10+
from pytensor.xtensor.type import (
11+
DIM_LENGTH_SCALAR,
12+
BaseDim,
13+
CloneDim,
14+
DimType,
15+
DimVariable,
16+
XTensorVariable,
17+
)
18+
19+
20+
class DimOp(Op):
21+
def perform(self, node, inputs, outputs):
22+
raise NotImplementedError(
23+
f"xtensor operation {self} must be lowered to equivalent tensor operations"
24+
)
25+
26+
27+
# Not a dim op, because it doesn't return a DimVariable
28+
class Length(Op):
29+
__props__ = ()
30+
31+
def make_node(self, *inputs: Variable) -> Apply:
32+
(x,) = inputs
33+
if not isinstance(x, DimVariable):
34+
raise TypeError(f"x must be a DimVariable, got {type(x.type)}")
35+
return Apply(self, [x], [DIM_LENGTH_SCALAR()])
36+
37+
def perform(self, node, inputs, outputs):
38+
outputs[0][0] = inputs[0]
39+
40+
41+
def _dim_size(dim: DimVariable) -> ScalarVariable:
42+
return Length()(dim)
43+
44+
45+
class FromLength(DimOp):
46+
__props__ = ("dim_type",)
47+
48+
def __init__(self, dim_type: DimType):
49+
super().__init__()
50+
self.dim_type = dim_type
51+
52+
def make_node(self, *inputs: Variable) -> Apply:
53+
(length,) = inputs
54+
if not isinstance(length, ScalarVariable):
55+
raise TypeError(f"length must be a ScalarVariable, got {type(length.type)}")
56+
if length.type != DIM_LENGTH_SCALAR:
57+
raise TypeError(
58+
f"length must be of dtype 'DIM_LENGTH_SCALAR', got {length.type.dtype}"
59+
)
60+
return Apply(self, [length], [self.dim_type()])
61+
62+
def perform(self, node, inputs, outputs):
63+
"""Convert the length to a list of lengths."""
64+
outputs[0][0] = inputs[0]
65+
66+
67+
def from_length(length: ScalarVariable, name: str | None = None) -> DimVariable:
68+
# TODO add check for dtype
69+
if not isinstance(length, ScalarVariable):
70+
raise TypeError(f"length must be a ScalarVariable, got {type(length.type)}")
71+
if length.type != DIM_LENGTH_SCALAR:
72+
raise TypeError(
73+
f"length must be of dtype 'DIM_LENGTH_SCALAR', got {length.type.dtype}"
74+
)
75+
76+
uuid = uuid4()
77+
dim_type = DimType(dim=BaseDim(uuid=uuid, name=name))
78+
op = FromLength(dim_type)
79+
return op(length, name=name)
80+
81+
82+
class FromTensor(Op):
83+
__props__ = ("dim_type",)
84+
85+
def __init__(self, dim_type: DimType):
86+
super().__init__()
87+
self.dim_type = dim_type
88+
89+
def make_node(self, *inputs: Variable) -> Apply:
90+
(x,) = inputs
91+
if not isinstance(x, XTensorVariable):
92+
raise TypeError(f"x must be an XTensorVariable, got {type(x.type)}")
93+
return Apply(self, [x], [self.dim_type()])
94+
95+
def perform(self, node, inputs, outputs):
96+
"""Convert the tensor to a dimension variable."""
97+
(x,) = inputs
98+
(x_var,) = node.inputs
99+
for i, dim in enumerate(x_var.type.dims):
100+
if dim == self.dim_type.dim:
101+
outputs[0][0] = x.shape[i]
102+
return
103+
raise ValueError(
104+
f"Dimension {self.dim_type.dim} not found in tensor {x.type.dims}"
105+
)
106+
107+
108+
def _dim_from_tensor(x: XTensorVariable, idx: int) -> DimVariable:
109+
op = FromTensor(dim_type=DimType(x.type.dims[idx]))
110+
return op(x, name=x.type.dims[idx].name)
111+
112+
113+
class Clone(Op):
114+
__props__ = ("dim_type",)
115+
116+
def __init__(self, dim_type):
117+
super().__init__()
118+
self.dim_type = dim_type
119+
120+
def make_node(self, *inputs: Variable) -> Apply:
121+
(x,) = inputs
122+
if not isinstance(x, DimVariable):
123+
raise TypeError(f"x must be a DimVariable, got {type(x.type)}")
124+
return Apply(self, [x], [self.dim_type()])
125+
126+
def perform(self, node, inputs, outputs):
127+
outputs[0][0] = inputs[0]
128+
129+
130+
def _clone_dim(dim: DimVariable, *, name: str | None = None) -> DimVariable:
131+
"""Rename a dimension variable.
132+
133+
Args:
134+
name: The new name for the dimension.
135+
136+
Returns:
137+
A new DimVariable with the updated name.
138+
"""
139+
dim_type = DimType(dim=CloneDim(uuid=uuid4(), base=dim.type.dim))
140+
return Clone(dim_type)(dim, name=name)
141+
142+
143+
class Product(Op):
144+
__props__ = ()
145+
146+
def make_node(self, *dims: Variable) -> Apply:
147+
if not all(isinstance(dim, DimVariable) for dim in dims):
148+
raise TypeError("All inputs must be DimVariables.")
149+
out = dim_type()
150+
return Apply(self, list(dims), [out])
151+
152+
def perform(self, node, inputs, outputs):
153+
outputs[0][0] = np.prod(inputs, dtype=DIM_LENGTH_SCALAR.dtype).item()
154+
155+
156+
def product_dim(*dims: DimVariable, name: str | None = None) -> DimVariable:
157+
return Product()(*dims, name=name)
158+
159+
160+
def rebase_dim(dim: DimVariable, *tensors: XTensorVariable) -> DimVariable:
161+
if not isinstance(dim, DimVariable):
162+
raise TypeError(f"dim must be a DimVariable, got {type(dim)}")
163+
164+
if not tensors:
165+
raise ValueError("At least one tensor must be provided for rebasing.")
166+
167+
for tensor in tensors:
168+
for i, tensor_dim in enumerate(tensor.type.dims):
169+
if dim.type.dim == tensor_dim:
170+
return _dim_from_tensor(tensor, idx=i)
171+
raise ValueError(
172+
f"Dimension {dim.type.dim} not found in any of the provided tensors."
173+
)

pytensor/xtensor/rewriting/basic.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,34 @@
11
from pytensor.graph import node_rewriter
2-
from pytensor.tensor.basic import register_infer_shape
3-
from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
42
from pytensor.xtensor.basic import (
5-
Rename,
3+
MapDims,
64
TensorFromXTensor,
75
XTensorFromTensor,
86
xtensor_from_tensor,
97
)
108
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
119

1210

13-
@register_infer_shape
14-
@register_useless
15-
@register_canonicalize
16-
@register_lower_xtensor
17-
@node_rewriter(tracks=[TensorFromXTensor])
11+
# @register_infer_shape
12+
# @register_useless
13+
# @register_canonicalize
14+
# @register_lower_xtensor
15+
# @node_rewriter(tracks=[TensorFromXTensor])
1816
def useless_tensor_from_xtensor(fgraph, node):
1917
"""TensorFromXTensor(XTensorFromTensor(x)) -> x"""
2018
[x] = node.inputs
2119
if x.owner and isinstance(x.owner.op, XTensorFromTensor):
2220
return [x.owner.inputs[0]]
2321

2422

25-
@register_infer_shape
26-
@register_useless
27-
@register_canonicalize
28-
@register_lower_xtensor
29-
@node_rewriter(tracks=[XTensorFromTensor])
23+
# @register_infer_shape
24+
# @register_useless
25+
# @register_canonicalize
26+
# @register_lower_xtensor
27+
# @node_rewriter(tracks=[XTensorFromTensor])
3028
def useless_xtensor_from_tensor(fgraph, node):
3129
"""XTensorFromTensor(TensorFromXTensor(x)) -> x"""
32-
[x] = node.inputs
30+
# TODO
31+
[x, *dims] = node.inputs
3332
if x.owner and isinstance(x.owner.op, TensorFromXTensor):
3433
return [x.owner.inputs[0]]
3534

@@ -39,13 +38,13 @@ def useless_xtensor_from_tensor(fgraph, node):
3938
def useless_tensor_from_xtensor_of_rename(fgraph, node):
4039
"""TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)"""
4140
[renamed_x] = node.inputs
42-
if renamed_x.owner and isinstance(renamed_x.owner.op, Rename):
41+
if renamed_x.owner and isinstance(renamed_x.owner.op, MapDims):
4342
[x] = renamed_x.owner.inputs
4443
return node.op(x, return_list=True)
4544

4645

4746
@register_lower_xtensor
48-
@node_rewriter(tracks=[Rename])
47+
@node_rewriter(tracks=[MapDims])
4948
def useless_rename(fgraph, node):
5049
"""
5150
@@ -54,7 +53,7 @@ def useless_rename(fgraph, node):
5453
"""
5554
[renamed_x] = node.inputs
5655
if renamed_x.owner:
57-
if isinstance(renamed_x.owner.op, Rename):
56+
if isinstance(renamed_x.owner.op, MapDims):
5857
[x] = renamed_x.owner.inputs
5958
return [node.op(x)]
6059
elif isinstance(renamed_x.owner.op, TensorFromXTensor):

0 commit comments

Comments
 (0)