Commit 2d36e9f

Implement index operations on XTensorTypes
1 parent 8c2d953 commit 2d36e9f

File tree: 2 files changed (143 additions & 8 deletions)


pytensor/xtensor/indexing.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def as_idx_variable(idx):
    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
        raise TypeError("XTensors do not support None (np.newaxis)")
    if isinstance(idx, slice):
        idx = make_slice(idx)
    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
        pass
    else:
        # Must be integer indices, we already accounted for None and slices
        try:
            idx = as_xtensor(idx)
        except TypeError:
            idx = as_tensor(idx)
        if idx.type.dtype not in discrete_dtypes:
            raise TypeError("Numerical indices must be integers or boolean")
        if idx.type.dtype == "bool" and idx.type.ndim == 0:
            raise NotImplementedError("Scalar boolean indices not supported")
    return idx


def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
    if dim_length is None:
        return None
    if isinstance(slc, Constant):
        d = slc.data
        start, stop, step = d.start, d.stop, d.step
    elif slc.owner is None:
        # It's a root variable, no way of knowing what we're getting
        return None
    else:
        # It's a MakeSliceOp
        start, stop, step = slc.owner.inputs
        if isinstance(start, Constant):
            start = start.data
        else:
            return None
        if isinstance(stop, Constant):
            stop = stop.data
        else:
            return None
        if isinstance(step, Constant):
            step = step.data
        else:
            return None
    return len(range(*slice(start, stop, step).indices(dim_length)))


class Index(XOp):
    __props__ = ()

    def make_node(self, x, *idxs):
        x = as_xtensor(x)
        idxs = [as_idx_variable(idx) for idx in idxs]

        x_ndim = x.type.ndim
        x_dims = x.type.dims
        x_shape = x.type.shape
        out_dims = []
        out_shape = []
        for i, idx in enumerate(idxs):
            if i == x_ndim:
                raise IndexError("Too many indices")
            if isinstance(idx.type, XTensorType):
                raise NotImplementedError(
                    "Indexing with XTensorType not yet supported."
                )
            if isinstance(idx.type, SliceType):
                out_dims.append(x_dims[i])
                out_shape.append(get_static_slice_length(idx, x_shape[i]))

            else:  # TensorType
                if idx.type.ndim == 0:
                    # Scalar, dimension is dropped
                    pass
                elif idx.type.ndim == 1:
                    out_dims.append(x_dims[i])
                    out_shape.append(idx.type.shape[0])
                else:
                    # Same error that xarray raises
                    raise IndexError(
                        "Unlabeled multi-dimensional array cannot be used for indexing"
                    )

        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, *idxs], [output])


index = Index()
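
A minimal usage sketch (not part of the commit) of how the new index op could be exercised, assuming the xtensor(dtype=..., shape=..., dims=...) constructor keyword signature used in make_node above:

    from pytensor.xtensor.indexing import index
    from pytensor.xtensor.type import xtensor

    # Symbolic labeled tensor with two named dimensions.
    x = xtensor(dtype="float64", shape=(10, 5), dims=("time", "city"))

    # A slice keeps its dimension (static length inferred by
    # get_static_slice_length), a scalar index drops its dimension.
    y = index(x, slice(1, 7, 2), 0)

    print(y.type.dims)   # expected: ("time",)
    print(y.type.shape)  # expected: (3,)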

pytensor/xtensor/type.py

Lines changed: 46 additions & 8 deletions
@@ -1,3 +1,6 @@
+import warnings
+
+
 try:
     import xarray as xr
 
@@ -6,13 +9,13 @@
     XARRAY_AVAILABLE = False
 
 from collections.abc import Sequence
-from typing import TypeVar
+from typing import Any, Literal, TypeVar
 
 import numpy as np
 
 from pytensor import _as_symbolic, config
 from pytensor.graph import Apply, Constant
-from pytensor.graph.basic import Variable, OptionalApplyType
+from pytensor.graph.basic import OptionalApplyType, Variable
 from pytensor.graph.type import HasDataType, HasShape, Type
 from pytensor.tensor.utils import hash_from_ndarray
 from pytensor.utils import hash_from_code
@@ -143,15 +146,50 @@ def __getitem__(self, idx):
 
         return index(self, *idx)
 
+    def sel(self, *args, **kwargs):
+        raise NotImplementedError(
+            "sel not implemented for XTensorVariable, use isel instead"
+        )
 
-class XTensorVariable(Variable):
-    pass
+    def isel(
+        self,
+        indexers: dict[str, Any] | None = None,
+        drop: bool = False,  # Unused by PyTensor
+        missing_dims: Literal["raise", "warn", "ignore"] = "raise",
+        **indexers_kwargs,
+    ):
+        from pytensor.xtensor.indexing import index
+
+        if indexers_kwargs:
+            if indexers is not None:
+                raise ValueError(
+                    "Cannot pass both indexers and indexers_kwargs to isel"
+                )
+            indexers = indexers_kwargs
 
-    # def __str__(self):
-    #     return f"{self.__class__.__name__}{{{self.format},{self.dtype}}}"
+        if missing_dims not in {"raise", "warn", "ignore"}:
+            raise ValueError(
+                f"Unrecognized options {missing_dims} for missing_dims argument"
+            )
 
-    # def __repr__(self):
-    #     return str(self)
+        # Sort indices and pass them to index
+        dims = self.type.dims
+        indices = [slice(None) for _ in range(self.type.ndim)]
+        for key, idx in indexers.items():
+            try:
+                indices[dims.index(key)] = idx
+            except ValueError:
+                if missing_dims == "raise":
+                    raise ValueError(
+                        f"Dimension {key} does not exist. Expected one of {dims}"
+                    )
+                elif missing_dims == "warn":
+                    warnings.warn(
+                        f"Dimension {key} does not exist. Expected one of {dims}",
+                        UserWarning,
+                    )
+
+        return index(self, *indices)
 
 
 class XTensorConstantSignature(tuple):
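
A minimal sketch (not part of this commit) of how the new isel method is meant to be used, assuming xtensor(dtype=..., shape=..., dims=...) returns a variable that exposes the methods added above:

    from pytensor.xtensor.type import xtensor

    x = xtensor(dtype="float64", shape=(10, 5), dims=("time", "city"))

    # Index by dimension name: the slice keeps "time", the scalar drops "city".
    y = x.isel(time=slice(0, 4), city=2)
    print(y.type.dims)   # expected: ("time",)

    # Unknown dimension names follow the missing_dims policy:
    # with "warn" a UserWarning is emitted and the dimension is left untouched.
    z = x.isel({"country": 0}, missing_dims="warn")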
