Skip to content

Commit b5d5ffa

Browse files
committed
Implement simpler idx_list
1 parent b12fdf6 commit b5d5ffa

File tree

9 files changed

+606
-432
lines changed

9 files changed

+606
-432
lines changed

pytensor/graph/destroyhandler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -771,9 +771,9 @@ def orderings(self, fgraph, ordered=True):
771771
}
772772
tolerated.add(destroyed_idx)
773773
tolerate_aliased = getattr(
774-
app.op, "destroyhandler_tolerate_aliased", []
774+
app.op, "destroyhandler_tolerate_aliased", ()
775775
)
776-
assert isinstance(tolerate_aliased, list)
776+
assert isinstance(tolerate_aliased, list | tuple)
777777
ignored = {
778778
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
779779
}

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from numba.core.pythonapi import box
1111

1212
import pytensor.link.numba.dispatch.basic as numba_basic
13-
from pytensor.graph import Type, Variable
13+
from pytensor.graph import Variable
1414
from pytensor.link.numba.cache import (
1515
compile_numba_function_src,
1616
)
@@ -29,6 +29,7 @@
2929
AdvancedSubtensor1,
3030
IncSubtensor,
3131
Subtensor,
32+
_is_position,
3233
indices_from_subtensor,
3334
)
3435
from pytensor.tensor.type_other import MakeSlice, NoneTypeT
@@ -158,7 +159,7 @@ def numba_funcify_default_subtensor(op, node, **kwargs):
158159
"""Create a Python function that assembles and uses an index on an array."""
159160

160161
def convert_indices(indices_iterator, entry):
161-
if hasattr(indices_iterator, "__next__") and isinstance(entry, Type):
162+
if hasattr(indices_iterator, "__next__") and _is_position(entry):
162163
name, var = next(indices_iterator)
163164
if var.ndim == 0 and isinstance(var.type, TensorType):
164165
return f"{name}.item()"
@@ -171,8 +172,6 @@ def convert_indices(indices_iterator, entry):
171172
)
172173
elif isinstance(entry, type(None)):
173174
return "None"
174-
elif isinstance(entry, (int, np.integer)):
175-
return str(entry)
176175
else:
177176
raise ValueError(f"Unknown index type: {entry}")
178177

pytensor/tensor/basic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pytensor.graph.op import Op
3030
from pytensor.graph.replace import _vectorize_node
3131
from pytensor.graph.rewriting.db import EquilibriumDB
32-
from pytensor.graph.type import HasShape, Type
32+
from pytensor.graph.type import HasShape
3333
from pytensor.link.c.op import COp
3434
from pytensor.link.c.params_type import ParamsType
3535
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
@@ -300,7 +300,7 @@ def _get_underlying_scalar_constant_value(
300300
"""
301301
from pytensor.compile.ops import DeepCopyOp, OutputGuard
302302
from pytensor.sparse import CSM
303-
from pytensor.tensor.subtensor import Subtensor
303+
from pytensor.tensor.subtensor import Subtensor, _is_position
304304

305305
v = orig_v
306306
while True:
@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value(
433433
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
434434
):
435435
idx = v.owner.op.idx_list[0]
436-
if isinstance(idx, Type):
436+
if _is_position(idx):
437437
idx = _get_underlying_scalar_constant_value(
438438
v.owner.inputs[1], max_recur=max_recur
439439
)
@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value(
467467
and len(v.owner.op.idx_list) == 1
468468
):
469469
idx = v.owner.op.idx_list[0]
470-
if isinstance(idx, Type):
470+
if _is_position(idx):
471471
idx = _get_underlying_scalar_constant_value(
472472
v.owner.inputs[1], max_recur=max_recur
473473
)
@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value(
488488
op = owner.op
489489
idx_list = op.idx_list
490490
idx = idx_list[0]
491-
if isinstance(idx, Type):
491+
if _is_position(idx):
492492
idx = _get_underlying_scalar_constant_value(
493493
owner.inputs[1], max_recur=max_recur
494494
)

pytensor/tensor/rewriting/shape.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
)
1818
from pytensor.graph.traversal import ancestors
1919
from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
20-
from pytensor.scalar import ScalarType
2120
from pytensor.tensor.basic import (
2221
MakeVector,
2322
as_tensor_variable,
@@ -45,7 +44,7 @@
4544
SpecifyShape,
4645
specify_shape,
4746
)
48-
from pytensor.tensor.subtensor import Subtensor, get_idx_list
47+
from pytensor.tensor.subtensor import Subtensor, _is_position, get_idx_list
4948
from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
5049
from pytensor.tensor.type_other import NoneTypeT
5150
from pytensor.tensor.variable import TensorVariable
@@ -845,13 +844,16 @@ def _is_shape_i_of_x(
845844
if isinstance(var.owner.op, Shape_i):
846845
return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore
847846

848-
# Match Subtensor((ScalarType,))(Shape(input), i)
847+
# Match Subtensor((int,))(Shape(input), i) - single integer index into shape
849848
if isinstance(var.owner.op, Subtensor):
849+
idx_entry = (
850+
var.owner.op.idx_list[0] if len(var.owner.op.idx_list) == 1 else None
851+
)
850852
return (
851853
# Check we have integer indexing operation
852854
# (and not slice or multiple indexing)
853855
len(var.owner.op.idx_list) == 1
854-
and isinstance(var.owner.op.idx_list[0], ScalarType)
856+
and _is_position(idx_entry)
855857
# Check we are indexing on the shape of x
856858
and var.owner.inputs[0].owner is not None
857859
and isinstance(var.owner.inputs[0].owner.op, Shape)

pytensor/tensor/rewriting/subtensor.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytensor
77
from pytensor import compile
88
from pytensor.compile import optdb
9-
from pytensor.graph.basic import Constant, Variable
9+
from pytensor.graph.basic import Constant, Variable, equal_computations
1010
from pytensor.graph.rewriting.basic import (
1111
WalkingGraphRewriter,
1212
copy_stack_trace,
@@ -15,7 +15,7 @@
1515
node_rewriter,
1616
)
1717
from pytensor.raise_op import Assert
18-
from pytensor.scalar import Add, ScalarConstant, ScalarType
18+
from pytensor.scalar import Add, ScalarConstant
1919
from pytensor.scalar import constant as scalar_constant
2020
from pytensor.tensor.basic import (
2121
Alloc,
@@ -72,6 +72,7 @@
7272
AdvancedSubtensor1,
7373
IncSubtensor,
7474
Subtensor,
75+
_is_position,
7576
advanced_inc_subtensor1,
7677
advanced_subtensor1,
7778
as_index_constant,
@@ -480,9 +481,8 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
480481
remove_dim = []
481482
node_inputs_idx = 1
482483
for dim, elem in enumerate(idx):
483-
if isinstance(elem, ScalarType):
484-
# The idx is a ScalarType, ie a Type. This means the actual index
485-
# is contained in node.inputs[1]
484+
if _is_position(elem):
485+
# The idx is a integer position.
486486
dim_index = node.inputs[node_inputs_idx]
487487
if isinstance(dim_index, ScalarConstant):
488488
dim_index = dim_index.value
@@ -494,9 +494,6 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
494494
elif isinstance(elem, slice):
495495
if elem != slice(None):
496496
return
497-
elif isinstance(elem, int | np.integer):
498-
if elem in (0, -1) and node.inputs[0].broadcastable[dim]:
499-
remove_dim.append(dim)
500497
else:
501498
raise TypeError("case not expected")
502499

@@ -508,6 +505,39 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
508505
return [node.inputs[0].dimshuffle(tuple(remain_dim))]
509506

510507

508+
def _idx_list_struct_equal(idx_list1, idx_list2):
509+
"""Check if two idx_lists have the same structure.
510+
511+
Positions (integers) are treated as equivalent regardless of value,
512+
since positions are relative to each Op's inputs.
513+
"""
514+
if len(idx_list1) != len(idx_list2):
515+
return False
516+
517+
def normalize_entry(entry):
518+
if isinstance(entry, int) and not isinstance(entry, bool):
519+
return "POS" # All positions are equivalent
520+
elif isinstance(entry, slice):
521+
return (
522+
"POS"
523+
if isinstance(entry.start, int) and not isinstance(entry.start, bool)
524+
else entry.start,
525+
"POS"
526+
if isinstance(entry.stop, int) and not isinstance(entry.stop, bool)
527+
else entry.stop,
528+
"POS"
529+
if isinstance(entry.step, int) and not isinstance(entry.step, bool)
530+
else entry.step,
531+
)
532+
else:
533+
return entry
534+
535+
for e1, e2 in zip(idx_list1, idx_list2):
536+
if normalize_entry(e1) != normalize_entry(e2):
537+
return False
538+
return True
539+
540+
511541
@register_specialize
512542
@register_canonicalize
513543
@node_rewriter([Subtensor])
@@ -523,9 +553,17 @@ def local_subtensor_inc_subtensor(fgraph, node):
523553
if not x.owner.op.set_instead_of_inc:
524554
return
525555

526-
if x.owner.inputs[2:] == node.inputs[1:] and tuple(
527-
x.owner.op.idx_list
528-
) == tuple(node.op.idx_list):
556+
# Check structural equality of idx_lists and semantic equality of inputs
557+
inc_inputs = x.owner.inputs[2:]
558+
sub_inputs = node.inputs[1:]
559+
560+
if (
561+
len(inc_inputs) == len(sub_inputs)
562+
and _idx_list_struct_equal(x.owner.op.idx_list, node.op.idx_list)
563+
and all(
564+
equal_computations([a], [b]) for a, b in zip(inc_inputs, sub_inputs)
565+
)
566+
):
529567
out = node.outputs[0]
530568
y = x.owner.inputs[1]
531569
# If the dtypes differ, cast y into x.dtype

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from pytensor.compile import optdb
99
from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
1010
from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
11-
from pytensor.scalar import basic as ps
1211
from pytensor.tensor.basic import (
1312
Alloc,
1413
Join,
@@ -42,6 +41,7 @@
4241
AdvancedSubtensor,
4342
AdvancedSubtensor1,
4443
Subtensor,
44+
_is_position,
4545
_non_consecutive_adv_indexing,
4646
as_index_literal,
4747
get_canonical_form_slice,
@@ -702,13 +702,13 @@ def local_subtensor_make_vector(fgraph, node):
702702

703703
(idx,) = idxs
704704

705-
if isinstance(idx, ps.ScalarType | TensorType):
706-
old_idx, idx = idx, node.inputs[1]
707-
assert idx.type.is_super(old_idx)
705+
if _is_position(idx):
706+
# idx is an integer position - get the actual index value from inputs
707+
idx = node.inputs[1]
708708
elif isinstance(node.op, AdvancedSubtensor1):
709709
idx = node.inputs[1]
710710

711-
if isinstance(idx, int | np.integer):
711+
if False: # isinstance(idx, int | np.integer) - disabled, positions handled above
712712
return [x.owner.inputs[idx]]
713713
elif isinstance(idx, Variable):
714714
if idx.ndim == 0:

0 commit comments

Comments (0)