4 changes: 2 additions & 2 deletions pytensor/graph/destroyhandler.py
@@ -771,9 +771,9 @@ def orderings(self, fgraph, ordered=True):
}
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(
app.op, "destroyhandler_tolerate_aliased", []
app.op, "destroyhandler_tolerate_aliased", ()
)
assert isinstance(tolerate_aliased, list)
assert isinstance(tolerate_aliased, list | tuple)
ignored = {
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
}
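Note on the `destroyhandler` change: `destroyhandler_tolerate_aliased` may now be declared as a tuple of `(destroyed_idx, aliased_idx)` pairs rather than only a list. A minimal standalone check of the relaxed assertion (illustrative values, not PyTensor code; uses the Python 3.10+ union syntax seen above):

```python
# The tuple form now passes the relaxed assertion; the pair semantics
# ("aliasing of input idx1 with destroyed input idx0 is tolerated") follow
# from the `ignored` comprehension in orderings().
tolerate_aliased = ((0, 1),)
assert isinstance(tolerate_aliased, list | tuple)

destroyed_idx = 0
ignored = {idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx}
print(ignored)  # {1}
```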
19 changes: 1 addition & 18 deletions pytensor/link/jax/dispatch/subtensor.py
@@ -48,6 +48,7 @@ def subtensor(x, *ilists):


@jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
@@ -75,24 +76,6 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
return incsubtensor


@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):

def jax_fn(x, indices, y):
return x.at[indices].set(y)

else:

def jax_fn(x, indices, y):
return x.at[indices].add(y)

def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)

return advancedincsubtensor


@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
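With `AdvancedIncSubtensor` now registered on the shared `jax_funcify_IncSubtensor`, the dedicated dispatcher removed above becomes redundant; both paths bottom out in JAX's functional index-update API. A small standalone sketch of those semantics (plain JAX, not the PyTensor dispatch code):

```python
import jax.numpy as jnp

x = jnp.zeros(5)
idx = jnp.array([0, 2, 2])  # note the duplicate index

# set_instead_of_inc=True -> functional "set": duplicates simply overwrite
print(x.at[idx].set(1.0))   # [1. 0. 1. 0. 0.]

# set_instead_of_inc=False -> functional "add": duplicates accumulate
print(x.at[idx].add(1.0))   # [1. 0. 2. 0. 0.]
```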
246 changes: 197 additions & 49 deletions pytensor/link/numba/dispatch/subtensor.py

Large diffs are not rendered by default.

21 changes: 15 additions & 6 deletions pytensor/link/pytorch/dispatch/subtensor.py
@@ -9,7 +9,7 @@
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, SliceType
from pytensor.tensor.type_other import MakeSlice


def check_negative_steps(indices):
@@ -64,6 +64,7 @@ def makeslice(start, stop, step):
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices):
indices = indices_from_subtensor(indices, op.idx_list)
check_negative_steps(indices)
return x[indices]

@@ -102,12 +103,14 @@ def inc_subtensor(x, y, *flattened_indices):
@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False)

if op.set_instead_of_inc:

def adv_set_subtensor(x, y, *indices):
def adv_set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
@@ -120,7 +123,8 @@ def adv_set_subtensor(x, y, *indices):

elif ignore_duplicates:

def adv_inc_subtensor_no_duplicates(x, y, *indices):
def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
@@ -132,13 +136,18 @@ def adv_inc_subtensor_no_duplicates(x, y, *indices):
return adv_inc_subtensor_no_duplicates

else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
# Check if we have slice indexing in idx_list
has_slice_indexing = (
any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
)
if has_slice_indexing:
raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)

def adv_inc_subtensor(x, y, *indices):
# Not needed because slices aren't supported
def adv_inc_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
# Not needed because slices aren't supported in this path
# check_negative_steps(indices)
if not inplace:
x = x.clone()
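The three PyTorch branches (set, increment ignoring duplicates, increment accumulating duplicates) map onto standard torch index-update primitives, after the indices are reassembled from the flattened scalar inputs via `indices_from_subtensor`. A standalone sketch of the semantics each branch relies on (plain PyTorch, not the dispatch code):

```python
import torch

x = torch.zeros(5)
idx = (torch.tensor([0, 2, 2]),)      # advanced integer index with a duplicate
y = torch.tensor([1.0, 1.0, 1.0])

# set_instead_of_inc=True: plain advanced-index assignment
x_set = x.clone()
x_set[idx] = y                                  # tensor([1., 0., 1., 0., 0.])

# ignore_duplicates=True: index_put_ without accumulation (last write wins)
x_nodup = x.clone()
x_nodup.index_put_(idx, y, accumulate=False)    # tensor([1., 0., 1., 0., 0.])

# default: index_put_ accumulating across duplicate indices
x_inc = x.clone()
x_inc.index_put_(idx, y, accumulate=True)       # tensor([1., 0., 2., 0., 0.])
```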
10 changes: 5 additions & 5 deletions pytensor/tensor/basic.py
@@ -29,7 +29,7 @@
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape, Type
from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
@@ -300,7 +300,7 @@ def _get_underlying_scalar_constant_value(
"""
from pytensor.compile.ops import DeepCopyOp, OutputGuard
from pytensor.sparse import CSM
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.subtensor import Subtensor, _is_position

v = orig_v
while True:
@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value(
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
if _is_position(idx):
idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value(
and len(v.owner.op.idx_list) == 1
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
if _is_position(idx):
idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
@@ -488,7 +488,7 @@
op = owner.op
idx_list = op.idx_list
idx = idx_list[0]
if isinstance(idx, Type):
if _is_position(idx):
idx = _get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur
)
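All three `basic.py` hunks swap `isinstance(idx, Type)` for the new `_is_position` helper imported from `pytensor.tensor.subtensor`, reflecting the changed `idx_list` representation introduced by this PR. The helper's implementation is not visible in this diff; conceptually it asks whether an `idx_list` entry stands for one of the node's scalar index inputs rather than a static slice. A rough, assumption-laden approximation:

```python
# Hypothetical sketch only -- the real _is_position lives in
# pytensor.tensor.subtensor and its exact definition is not part of this diff.
# Conceptually, x[a:b, i] is stored as a static template (idx_list) plus the
# dynamic scalar inputs (a, b, i); "positions" are the entries that refer to
# those scalar inputs, as opposed to slice templates.
def is_position_sketch(entry) -> bool:
    return not isinstance(entry, slice)

print(is_position_sketch(slice(0, None)))  # False: slice template
print(is_position_sketch(0))               # True: refers to a scalar input
```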
31 changes: 18 additions & 13 deletions pytensor/tensor/random/rewriting/basic.py
@@ -237,20 +237,22 @@ def is_nd_advanced_idx(idx, dtype) -> bool:
return False

# Parse indices
if isinstance(subtensor_op, Subtensor):
if isinstance(subtensor_op, Subtensor | AdvancedSubtensor):
indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list)
else:
indices = node.inputs[1:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
if any(
is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT)
for idx in indices
):
return False

# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
if any(
is_nd_advanced_idx(idx, integer_dtypes)
or isinstance(getattr(idx, "type", None), NoneTypeT)
for idx in indices
):
return False

# Check that indexing does not act on support dims
batch_ndims = rv_op.batch_ndim(rv_node)
@@ -269,8 +271,11 @@ def is_nd_advanced_idx(idx, dtype) -> bool:
)
for idx in supp_indices:
if not (
isinstance(idx.type, SliceType)
and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
(isinstance(idx, slice) and idx == slice(None))
or (
isinstance(getattr(idx, "type", None), SliceType)
and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
)
):
return False
n_discarded_idxs = len(supp_indices)
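The support-dimension check in the random-variable rewrite now also accepts a literal `slice(None)` entry taken directly from `idx_list`, in addition to a symbolic `SliceType` variable whose components are all `NoneTypeT`. The static half of that test is ordinary Python slice equality:

```python
# A trivial (full) slice leaves the support dimension untouched.
idx = slice(None)
print(isinstance(idx, slice) and idx == slice(None))  # True
print(slice(0, None) == slice(None))                  # False: not a full slice
```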
10 changes: 6 additions & 4 deletions pytensor/tensor/rewriting/shape.py
@@ -17,7 +17,6 @@
)
from pytensor.graph.traversal import ancestors
from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
from pytensor.scalar import ScalarType
from pytensor.tensor.basic import (
MakeVector,
as_tensor_variable,
@@ -45,7 +44,7 @@
SpecifyShape,
specify_shape,
)
from pytensor.tensor.subtensor import Subtensor, get_idx_list
from pytensor.tensor.subtensor import Subtensor, _is_position, get_idx_list
from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.variable import TensorVariable
@@ -845,13 +844,16 @@ def _is_shape_i_of_x(
if isinstance(var.owner.op, Shape_i):
return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore

# Match Subtensor((ScalarType,))(Shape(input), i)
# Match Subtensor((int,))(Shape(input), i) - single integer index into shape
if isinstance(var.owner.op, Subtensor):
idx_entry = (
var.owner.op.idx_list[0] if len(var.owner.op.idx_list) == 1 else None
)
return (
# Check we have integer indexing operation
# (and not slice or multiple indexing)
len(var.owner.op.idx_list) == 1
and isinstance(var.owner.op.idx_list[0], ScalarType)
and _is_position(idx_entry)
# Check we are indexing on the shape of x
and var.owner.inputs[0].owner is not None
and isinstance(var.owner.inputs[0].owner.op, Shape)
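`_is_shape_i_of_x` matches either a direct `Shape_i(i)(x)` node or the equivalent single-integer `Subtensor` applied to `Shape(x)`; the latter is the pattern produced by ordinary shape indexing, now detected via `_is_position` on the lone `idx_list` entry. For reference, a small example of how that pattern arises (standard PyTensor usage; the exact op printed may vary with rewrites):

```python
import pytensor.tensor as pt

x = pt.matrix("x")
s1 = x.shape[1]     # typically Subtensor(Shape(x), 1) at graph-construction time
print(s1.owner.op)  # the node _is_shape_i_of_x would be asked to match
```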