Add transpose() for labeled tensors #1427
@@ -0,0 +1,142 @@

# HERE LIE DRAGONS
# Useful links to make sense of all the numpy/xarray complexity
# https://numpy.org/devdocs/user/basics.indexing.html
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def as_idx_variable(idx):
    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
        raise TypeError(
            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
        )
    if isinstance(idx, slice):
        idx = make_slice(idx)
    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
        pass
    else:
        # Must be integer indices; we already accounted for None and slices
        try:
            idx = as_tensor(idx)
        except TypeError:
            idx = as_xtensor(idx)
        if idx.type.dtype == "bool":
            raise NotImplementedError("Boolean indexing not yet supported")
        if idx.type.dtype not in discrete_dtypes:
            raise TypeError("Numerical indices must be integers or boolean")
        if idx.type.dtype == "bool" and idx.type.ndim == 0:
            # This can't be triggered right now, but will once we lift the boolean restriction
            raise NotImplementedError("Scalar boolean indices not supported")
    return idx
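For orientation, a quick sketch (illustrative, not part of the diff) of which branch each kind of index takes in `as_idx_variable`:

```python
# Illustrative calls, assuming as_idx_variable as defined above
as_idx_variable(slice(1, None, 2))  # Python slice -> wrapped by make_slice
as_idx_variable(0)                  # scalar int -> as_tensor scalar index
as_idx_variable([0, 2])             # list of ints -> as_tensor vector index
# as_idx_variable(None)             # raises TypeError: use expand_dims instead
```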
def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
    if dim_length is None:
        return None
    if isinstance(slc, Constant):
        d = slc.data
        start, stop, step = d.start, d.stop, d.step
    elif slc.owner is None:
        # It's a root variable; no way of knowing what we're getting
        return None
    else:
        # It's a MakeSliceOp
        start, stop, step = slc.owner.inputs
        if isinstance(start, Constant):
            start = start.data
        else:
            return None
        if isinstance(stop, Constant):
            stop = stop.data
        else:
            return None
        if isinstance(step, Constant):
            step = step.data
        else:
            return None
    return len(range(*slice(start, stop, step).indices(dim_length)))
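The final line delegates to Python's own slice arithmetic: `slice.indices(dim_length)` normalizes start/stop/step against a concrete dimension length, and `len(range(...))` counts the selected positions. A couple of worked cases:

```python
# x[1::2] on a dimension of length 10 selects indices 1, 3, 5, 7, 9
assert len(range(*slice(1, None, 2).indices(10))) == 5

# x[::-1] on a dimension of length 3 selects indices 2, 1, 0
assert len(range(*slice(None, None, -1).indices(3))) == 3
```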
class Index(XOp):
    __props__ = ()

    def make_node(self, x, *idxs):
        x = as_xtensor(x)
        idxs = [as_idx_variable(idx) for idx in idxs]

        x_ndim = x.type.ndim
        x_dims = x.type.dims
        x_shape = x.type.shape
        out_dims = []
        out_shape = []
        has_unlabeled_vector_idx = False
        has_labeled_vector_idx = False
        for i, idx in enumerate(idxs):
            if i == x_ndim:
                raise IndexError("Too many indices")
            if isinstance(idx.type, SliceType):
                out_dims.append(x_dims[i])
                out_shape.append(get_static_slice_length(idx, x_shape[i]))
            elif isinstance(idx.type, XTensorType):
                if has_unlabeled_vector_idx:
                    raise NotImplementedError(
                        "Mixing of labeled and unlabeled vector indexing not implemented"
                    )
                has_labeled_vector_idx = True
                idx_dims = idx.type.dims
                for dim in idx_dims:
                    idx_dim_shape = idx.type.shape[idx_dims.index(dim)]
                    if dim in out_dims:
                        # Dim already introduced in output by a previous index
                        # Update static shape or raise if incompatible
                        out_dim_pos = out_dims.index(dim)
                        out_dim_shape = out_shape[out_dim_pos]
                        if out_dim_shape is None:
                            # We don't know the size of the dimension yet
                            out_shape[out_dim_pos] = idx_dim_shape
                        elif idx_dim_shape is not None and idx_dim_shape != out_dim_shape:
                            raise IndexError(
                                f"Dimension of indexers mismatch for dim {dim}"
                            )
                    else:
                        # New dimension
                        out_dims.append(dim)
                        out_shape.append(idx_dim_shape)
            else:  # TensorType
                if idx.type.ndim == 0:
                    # Scalar, dimension is dropped
                    pass
                elif idx.type.ndim == 1:
                    if has_labeled_vector_idx:
                        raise NotImplementedError(
                            "Mixing of labeled and unlabeled vector indexing not implemented"
                        )
                    has_unlabeled_vector_idx = True
                    out_dims.append(x_dims[i])
                    out_shape.append(idx.type.shape[0])
                else:
                    # Same error that xarray raises
                    raise IndexError(
                        "Unlabeled multi-dimensional array cannot be used for indexing"
                    )
        for j in range(i + 1, x_ndim):
            # Add any unindexed dimensions
            out_dims.append(x_dims[j])
            out_shape.append(x_shape[j])

        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, *idxs], [output])


index = Index()
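To make the output-type rules of `make_node` concrete, a hypothetical usage sketch (names and shapes are made up, and it assumes the `xtensor` helper accepts `dims`, `shape`, and `dtype` as used elsewhere in this diff):

```python
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("a", "b"), shape=(5, 7))

# A slice keeps the dimension; its static length is computed when possible
y = index(x, slice(1, 4))  # dims ("a", "b"), shape (3, 7)

# A scalar index drops the dimension
z = index(x, 1)  # dims ("b",), shape (7,)

# A labeled vector index introduces its own dimension name
i = xtensor("i", dims=("new",), shape=(4,), dtype="int64")
w = index(x, i)  # dims ("new", "b"), shape (4, 7)
```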
@@ -1,4 +1,5 @@

import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.indexing
import pytensor.xtensor.rewriting.reduction
import pytensor.xtensor.rewriting.shape
import pytensor.xtensor.rewriting.vectorization
@@ -0,0 +1,67 @@

from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import TensorType, specify_shape
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
from pytensor.xtensor.type import XTensorType


def to_basic_idx(idx):
    if isinstance(idx.type, SliceType):
        if isinstance(idx, Constant):
            return idx.data
        elif idx.owner:
            # MakeSlice Op
            # We transform NoneConsts to regular None so that basic Subtensor can be used if possible
            return slice(
                *[
                    None if isinstance(i.type, NoneTypeT) else i
                    for i in idx.owner.inputs
                ]
            )
        else:
            return idx
    if (
        isinstance(idx.type, XTensorType | TensorType)
        and idx.type.ndim == 0
        and idx.type.dtype != "bool"  # dtype is a string, so compare to "bool"
    ):
        return idx
    raise TypeError("Cannot convert idx to basic idx")


def _count_idx_types(idxs):
    basic, vector, xvector = 0, 0, 0
    for idx in idxs:
        if isinstance(idx.type, SliceType):
            basic += 1
        elif idx.type.ndim == 0:
            basic += 1
        elif isinstance(idx.type, TensorType):
            vector += 1
        else:
            xvector += 1
    return basic, vector, xvector


@register_xcanonicalize
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
    x, *idxs = node.inputs
    [out] = node.outputs
    x_tensor = tensor_from_xtensor(x)
    n_basic, n_vector, n_xvector = _count_idx_types(idxs)
    if n_xvector == 0 and n_vector == 0:
        x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]
    elif n_vector == 1 and n_xvector == 0:
        # Special case for single vector index, no orthogonal indexing
        x_tensor_indexed = x_tensor[tuple(idxs)]
    else:
        # Not yet implemented
        return None

    # Add lost shape if any
    x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
    new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
    return [new_out]
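The one-vector special case works because numpy's advanced indexing only diverges from xarray's orthogonal indexing when several array indices broadcast against each other; with a single vector the two semantics agree. A numpy illustration (not part of the diff):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)

# One vector index: plain numpy indexing is already "orthogonal"
assert x[[0, 2]].shape == (2, 4)

# Two vector indices: numpy pairs them elementwise ...
assert x[[0, 2], [1, 3]].shape == (2,)
# ... while orthogonal indexing takes the full cross product
assert x[np.ix_([0, 2], [1, 3])].shape == (2, 2)
```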
@@ -73,6 +73,67 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])

    return y


def expand_ellipsis(
    dims: tuple[str, ...], all_dims: tuple[str, ...]
) -> tuple[str, ...]:
    """Expand ellipsis in a dimension permutation.

    Parameters
    ----------
    dims : tuple[str, ...]
        The dimension permutation, which may contain an ellipsis
    all_dims : tuple[str, ...]
        All available dimensions

    Returns
    -------
    tuple[str, ...]
        The expanded dimension permutation
    """
    if dims == () or dims == (...,):
        return tuple(reversed(all_dims))

    if ... not in dims:
        return dims

    pre = []
    post = []
    found = False
    for d in dims:
        if d is ...:
            found = True
        elif not found:
            pre.append(d)
        else:
            post.append(d)
    middle = [d for d in all_dims if d not in pre + post]
    return tuple(pre + middle + post)

Review comment on `expand_ellipsis`: Should this check that there's at most one ellipsis in `dims` and raise otherwise?
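A few illustrative expansions, assuming `all_dims = ("a", "b", "c")`; the empty and bare-ellipsis cases reverse the order, matching xarray's `transpose()` with no arguments:

```python
all_dims = ("a", "b", "c")

assert expand_ellipsis((), all_dims) == ("c", "b", "a")
assert expand_ellipsis((...,), all_dims) == ("c", "b", "a")
assert expand_ellipsis(("b", ...), all_dims) == ("b", "a", "c")
assert expand_ellipsis((..., "a"), all_dims) == ("b", "c", "a")
```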
class Transpose(XOp):
    __props__ = ("dims",)

    def __init__(self, dims: tuple[str, ...]):
Suggested change:

-    def __init__(self, dims: tuple[str, ...]):
+    def __init__(self, dims: tuple[str | Ellipsis, ...]):

Review comment: xarray has a `missing_dims` argument that we could provide here as well: https://docs.xarray.dev/en/latest/generated/xarray.DataArray.transpose.html
We already do that for `isel` (pytensor/pytensor/xtensor/type.py, lines 362 to 368 in e32d865):

    def isel(
        self,
        indexers: dict[str, Any] | None = None,
        drop: bool = False,  # Unused by PyTensor
        missing_dims: Literal["raise", "warn", "ignore"] = "raise",
        **indexers_kwargs,
    ):
Review comment: you don't need to `expand_ellipsis` again; you have the ground truth in `node.outputs[0].type.dims`, which you already computed in `make_node`.
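A rough sketch of what that suggestion implies for the lowering rewrite (hypothetical code, not part of this diff; it assumes a `lower_transpose` rewrite analogous to `lower_index` above, and that `Transpose` lives in `pytensor.xtensor.shape`):

```python
from pytensor.graph import node_rewriter
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
from pytensor.xtensor.shape import Transpose  # assumed module for the new Op


@register_xcanonicalize
@node_rewriter(tracks=[Transpose])
def lower_transpose(fgraph, node):
    [x] = node.inputs
    [out] = node.outputs
    # Ellipsis was already expanded in make_node, so out.type.dims is the
    # ground truth; map each output dim back to its axis in the input
    axes = [x.type.dims.index(d) for d in out.type.dims]
    x_tensor = tensor_from_xtensor(x).transpose(*axes)
    return [xtensor_from_tensor(x_tensor, dims=out.type.dims)]
```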