Skip to content

Commit 1ce338c

Browse files
committed
work on tests
1 parent 2c0e780 commit 1ce338c

File tree

13 files changed

+155
-70
lines changed

13 files changed

+155
-70
lines changed

pytensor/xtensor/basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pytensor.tensor.basic import zeros as tensor_zeros
66
from pytensor.tensor.shape import specify_shape
77
from pytensor.tensor.type import TensorType
8-
from pytensor.xtensor.type import DimVariable, XTensorType, as_xtensor, xtensor
8+
from pytensor.xtensor.type import DimVariable, XTensorType, as_dim, as_xtensor, xtensor
99

1010

1111
DIM_LENGTH_SCALAR = uint64
@@ -64,6 +64,7 @@ def L_op(self, inputs, outs, g_outs):
6464
def xtensor_from_tensor(x, dims, name=None, check: bool = True):
6565
if check:
6666
x = specify_shape(x, [dim.size for dim in dims])
67+
dims = [as_dim(dim) for dim in dims]
6768
return XTensorFromTensor()(x, *dims, name=name)
6869

6970

pytensor/xtensor/dims.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def perform(self, node, inputs, outputs):
4141

4242

4343
def _dim_size(dim: DimVariable) -> DIM_LENGTH_VARIABLE:
44+
if dim.type.size is not None:
45+
return DIM_LENGTH_TYPE.filter_variable(dim.type.size)
4446
return Length()(dim)
4547

4648

@@ -162,21 +164,26 @@ def product_dim(*dims: DimVariable, name: str | None = None) -> DimVariable:
162164
return Product()(*dims, name=name)
163165

164166

165-
def rebase_dim(dim: DimVariable, *tensors: XTensorVariable) -> DimVariable:
166-
if not isinstance(dim, DimVariable):
167+
def rebase_dim(dim: DimVariable | DimType, *tensors: XTensorVariable) -> DimVariable:
168+
if not isinstance(dim, DimVariable | DimType):
167169
raise TypeError(f"dim must be a DimVariable, got {type(dim)}")
168170

169171
if not tensors:
170172
raise ValueError("At least one tensor must be provided for rebasing.")
171173

174+
if isinstance(dim, DimVariable):
175+
dim_type = dim.type
176+
else:
177+
dim_type = dim
178+
172179
for tensor in tensors:
173180
for i, tensor_dim in enumerate(tensor.type.dims):
174-
if dim.type == tensor_dim:
181+
if dim_type == tensor_dim:
175182
return _dim_from_tensor(tensor, idx=i)
176-
raise ValueError(f"Dimension {dim.type} not found in any of the provided tensors.")
183+
raise ValueError(f"Dimension {dim} not found in any of the provided tensors.")
177184

178185

179186
def rebase_dims(
180-
dims: Iterable[DimVariable], *tensors: XTensorVariable
187+
dims: Iterable[DimVariable | DimType], *tensors: XTensorVariable
181188
) -> list[DimVariable]:
182189
return [rebase_dim(dim, *tensors) for dim in dims]

pytensor/xtensor/math.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from pytensor.graph.basic import Apply
1010
from pytensor.scalar.basic import _cast_mapping, upcast
1111
from pytensor.xtensor.basic import XOp, as_xtensor
12-
from pytensor.xtensor.type import AsDim, DimType, DimVariable, as_dim, xtensor
12+
from pytensor.xtensor.dims import rebase_dims
13+
from pytensor.xtensor.type import AsDim, DimType, as_dim_type, xtensor
1314
from pytensor.xtensor.vectorization import XElemwise
1415

1516

@@ -527,7 +528,7 @@ class Dot(XOp):
527528
__props__ = ("dims",)
528529

529530
def __init__(self, dims: Iterable[DimType]):
530-
self.dims = dims
531+
self.dims = frozenset(dims)
531532
super().__init__()
532533

533534
def make_node(self, x, y):
@@ -544,7 +545,7 @@ def make_node(self, x, y):
544545
# Determine output dtype
545546
out_dtype = upcast(x.type.dtype, y.type.dtype)
546547

547-
out = xtensor(dtype=out_dtype, dims=out_dims)
548+
out = xtensor(dtype=out_dtype, dims=rebase_dims(out_dims, x, y))
548549
return Apply(self, [x, y], [out])
549550

550551

@@ -589,7 +590,6 @@ def dot(x, y, dim: str | Iterable[AsDim] | EllipsisType | None = None):
589590

590591
x_dims = set(x.type.dims)
591592
y_dims = set(y.type.dims)
592-
print(x_dims, y_dims, dim)
593593
intersection = x_dims & y_dims
594594
union = x_dims | y_dims
595595

@@ -599,9 +599,9 @@ def dot(x, y, dim: str | Iterable[AsDim] | EllipsisType | None = None):
599599
elif dim is ...:
600600
dim_set = union
601601
elif isinstance(dim, Iterable):
602-
dim_set = set([as_dim(dim).type for dim in dim])
603-
elif isinstance(dim, (DimVariable, DimType, str)):
604-
dim_set = {as_dim(dim).type}
602+
dim_set = {as_dim_type(dim) for dim in dim}
603+
elif isinstance(dim, AsDim):
604+
dim_set = {as_dim_type(dim)}
605605
else:
606606
raise TypeError(f"Unknown type {dim} for dimension")
607607

pytensor/xtensor/reduction.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,20 @@
77
from pytensor.graph.basic import Apply
88
from pytensor.tensor.math import variadic_mul
99
from pytensor.xtensor.basic import XOp
10+
from pytensor.xtensor.dims import rebase_dims
1011
from pytensor.xtensor.math import neq, sqrt
1112
from pytensor.xtensor.math import sqr as square
12-
from pytensor.xtensor.type import DimType, DimVariable, as_xtensor, xtensor
13+
from pytensor.xtensor.type import (
14+
AsDim,
15+
DimType,
16+
DimVariable,
17+
as_dim_type,
18+
as_xtensor,
19+
xtensor,
20+
)
1321

1422

15-
REDUCE_DIM = DimVariable | Sequence[DimVariable] | EllipsisType | None
23+
REDUCE_DIM = DimVariable | Sequence[AsDim] | EllipsisType | None
1624

1725

1826
class XReduce(XOp):
@@ -22,28 +30,28 @@ def __init__(self, binary_op, dims: Sequence[DimVariable]):
2230
super().__init__()
2331
self.binary_op = binary_op
2432
# Order of reduce dims doesn't change the behavior of the Op
25-
self.dims = tuple(dims)
33+
self.dims = frozenset(dims)
2634

2735
def make_node(self, x):
2836
x = as_xtensor(x)
2937
x_dims = x.type.dims
3038
x_dims_set = set(x_dims)
3139
reduce_dims_set = set(self.dims)
3240
if x_dims_set == reduce_dims_set:
33-
out_dims, out_shape = [], []
41+
out_dim_types, out_shape = [], []
3442
else:
3543
if not reduce_dims_set.issubset(x_dims_set):
3644
raise ValueError(
3745
f"Reduced dims {self.dims} not found in array dimensions {x_dims}."
3846
)
39-
out_dims, out_shape = zip(
47+
out_dim_types, out_shape = zip(
4048
*[
4149
(d, s)
4250
for d, s in zip(x_dims, x.type.shape)
4351
if d not in reduce_dims_set
4452
]
4553
)
46-
output = xtensor(dtype=x.type.dtype, dims=out_dims)
54+
output = xtensor(dtype=x.type.dtype, dims=rebase_dims(out_dim_types, x))
4755
return Apply(self, [x], [output])
4856

4957

@@ -53,7 +61,7 @@ def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[DimType]:
5361
elif dim is None or dim is Ellipsis:
5462
x = as_xtensor(x)
5563
return typing.cast(tuple[DimType], x.type.dims)
56-
return tuple(dim.type for dim in dim)
64+
return tuple(as_dim_type(dim) for dim in dim)
5765

5866

5967
def reduce(x, dim: REDUCE_DIM = None, *, binary_op):
@@ -112,9 +120,9 @@ def std(x, dim: REDUCE_DIM, *, ddof: int = 0):
112120
class XCumReduce(XOp):
113121
__props__ = ("binary_op", "dims")
114122

115-
def __init__(self, binary_op, dims: Sequence[str]):
123+
def __init__(self, binary_op, dims: Sequence[DimType]):
116124
self.binary_op = binary_op
117-
self.dims = tuple(sorted(dims)) # Order doesn't matter
125+
self.dims = frozenset(dims)
118126

119127
def make_node(self, x):
120128
x = as_xtensor(x)

pytensor/xtensor/rewriting/basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@ def useless_length(fgraph, node):
5151
return [dim.owner.inputs[0]]
5252

5353

54+
@register_infer_shape
55+
@register_useless
56+
@register_canonicalize
57+
@register_lower_xtensor
58+
@node_rewriter(tracks=[Length])
59+
def known_length(fgraph, node):
60+
"""Length(dim_with_size) -> size"""
61+
[dim] = node.inputs
62+
if dim.type.size is not None:
63+
return [dim.type.size]
64+
65+
5466
@register_infer_shape
5567
@register_useless
5668
@register_canonicalize

pytensor/xtensor/rewriting/reduction.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def lower_reduce(fgraph, node):
1717
x_dims = x.type.dims
1818
reduce_dims = node.op.dims
1919
reduce_axis = [x_dims.index(dim) for dim in reduce_dims]
20+
out_dims = [x_dim for x_dim in x.dims if x_dim.type not in reduce_dims]
2021

2122
if not reduce_axis:
2223
return [x]
@@ -40,7 +41,7 @@ def lower_reduce(fgraph, node):
4041

4142
x_tensor = tensor_from_xtensor(x)
4243
out_tensor = tensor_op_class(axis=reduce_axis)(x_tensor)
43-
new_out = xtensor_from_tensor(out_tensor, out.type.dims)
44+
new_out = xtensor_from_tensor(out_tensor, out_dims)
4445
return [new_out]
4546

4647

@@ -51,6 +52,7 @@ def lower_cumreduce(fgraph, node):
5152
x_dims = x.type.dims
5253
reduce_dims = node.op.dims
5354
reduce_axis = [x_dims.index(dim) for dim in reduce_dims]
55+
out_dims = [x_dim for x_dim in x.dims if x_dim not in reduce_dims]
5456

5557
if not reduce_axis:
5658
return [x]
@@ -68,5 +70,5 @@ def lower_cumreduce(fgraph, node):
6870
out_tensor = tensor_from_xtensor(x)
6971
for axis in reduce_axis:
7072
out_tensor = tensor_op_class(axis=axis)(out_tensor)
71-
out = xtensor_from_tensor(out_tensor, x.type.dims)
73+
out = xtensor_from_tensor(out_tensor, out_dims)
7274
return [out]

pytensor/xtensor/rewriting/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
77
from pytensor.tensor.rewriting.ofg import inline_ofg_expansion
88
from pytensor.tensor.variable import TensorVariable
9-
from pytensor.xtensor.type import XTensorVariable
9+
from pytensor.xtensor.type import AsDim, XTensorVariable, as_dim_type
1010

1111

1212
lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
@@ -56,8 +56,9 @@ def register(inner_rewriter: RewriteDatabase | NodeRewriter):
5656
return node_rewriter
5757

5858

59-
def lower_aligned(x: XTensorVariable, out_dims: Sequence[str]) -> TensorVariable:
59+
def lower_aligned(x: XTensorVariable, out_dims: Sequence[AsDim]) -> TensorVariable:
6060
"""Lower an XTensorVariable to a TensorVariable so that it's dimensions are aligned with "out_dims"."""
61+
out_dim_types = [as_dim_type(x) for x in out_dims]
6162
inp_dims = {d: i for i, d in enumerate(x.type.dims)}
62-
ds_order = tuple(inp_dims.get(dim, "x") for dim in out_dims)
63+
ds_order = tuple(inp_dims.get(dim, "x") for dim in out_dim_types)
6364
return typing.cast(TensorVariable, x.values.dimshuffle(ds_order))

pytensor/xtensor/shape.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
as_xtensor,
2323
xtensor,
2424
)
25-
from pytensor.xtensor.vectorization import combine_dims_and_shape
2625

2726

2827
class Stack(XOp):

pytensor/xtensor/type.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,16 +1138,36 @@ def as_symbolic_xarray(x, **kwargs):
11381138

11391139
def as_dim(
11401140
x: AsDim,
1141+
*,
1142+
allow_new: bool = False,
11411143
) -> DimVariable:
11421144
if isinstance(x, DimVariable):
11431145
return x
11441146
if isinstance(x, str):
1145-
return dim(name=x, unique=False)
1147+
if allow_new:
1148+
return dim(name=x, unique=True)
1149+
else:
1150+
raise ValueError(
1151+
f"Cannot convert string {x} to dim without allow_new=True. "
1152+
"Use `dim(name=x)` to create a new dimension."
1153+
)
11461154
if isinstance(x, DimType):
1147-
return cast(DimVariable, x())
1155+
if allow_new:
1156+
return cast(DimVariable, x())
1157+
else:
1158+
raise ValueError(
1159+
f"Cannot convert DimType {x} to dim without allow_new=True. "
1160+
"Use `x.make_variable()` to create a new dimension variable."
1161+
)
11481162
raise ValueError(f"Can not convert {type(x)} to dim.")
11491163

11501164

1165+
def as_dim_type(
1166+
x: AsDim,
1167+
) -> DimType:
1168+
return as_dim(x, allow_new=True).type
1169+
1170+
11511171
def as_xtensor(
11521172
x,
11531173
dims: Sequence[DimVariable] | None = None,

pytensor/xtensor/vectorization.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
DimType,
2020
DimVariable,
2121
XTensorVariable,
22+
as_dim,
2223
as_xtensor,
2324
xtensor,
2425
)
@@ -27,12 +28,15 @@
2728
def broadcast_xtensors(
2829
inputs: Sequence[XTensorVariable], exclude: Sequence[AsDim] | None = None
2930
) -> list[DimVariable]:
31+
if exclude is None:
32+
exclude = []
33+
exclude_set: set[DimType] = {as_dim(d).type for d in exclude}
3034
dims_and_shape: dict[DimType, int | None] = {}
3135
dim_to_dimvar: dict[DimType, DimVariable] = {}
3236
for inp in inputs:
3337
for dim, dim_length in zip(inp.dims, inp.type.shape):
3438
# TODO Must check dim conversion!!!
35-
if dim in exclude:
39+
if dim.type in exclude_set:
3640
continue
3741
if dim.type not in dims_and_shape:
3842
dims_and_shape[dim.type] = dim_length

0 commit comments

Comments
 (0)