66import pytensor
77from pytensor import compile
88from pytensor .compile import optdb
9- from pytensor .graph .basic import Constant , Variable
9+ from pytensor .graph .basic import Constant , Variable , equal_computations
1010from pytensor .graph .rewriting .basic import (
1111 WalkingGraphRewriter ,
1212 copy_stack_trace ,
1515 node_rewriter ,
1616)
1717from pytensor .raise_op import Assert
18- from pytensor .scalar import Add , ScalarConstant , ScalarType
18+ from pytensor .scalar import Add , ScalarConstant
1919from pytensor .scalar import constant as scalar_constant
2020from pytensor .tensor .basic import (
2121 Alloc ,
7272 AdvancedSubtensor1 ,
7373 IncSubtensor ,
7474 Subtensor ,
75+ _is_position ,
7576 advanced_inc_subtensor1 ,
7677 advanced_subtensor1 ,
7778 as_index_constant ,
@@ -480,9 +481,8 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
480481 remove_dim = []
481482 node_inputs_idx = 1
482483 for dim , elem in enumerate (idx ):
483- if isinstance (elem , ScalarType ):
484- # The idx is a ScalarType, ie a Type. This means the actual index
485- # is contained in node.inputs[1]
484+ if _is_position (elem ):
485+ # The idx is a integer position.
486486 dim_index = node .inputs [node_inputs_idx ]
487487 if isinstance (dim_index , ScalarConstant ):
488488 dim_index = dim_index .value
@@ -494,9 +494,6 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
494494 elif isinstance (elem , slice ):
495495 if elem != slice (None ):
496496 return
497- elif isinstance (elem , int | np .integer ):
498- if elem in (0 , - 1 ) and node .inputs [0 ].broadcastable [dim ]:
499- remove_dim .append (dim )
500497 else :
501498 raise TypeError ("case not expected" )
502499
@@ -508,6 +505,39 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
508505 return [node .inputs [0 ].dimshuffle (tuple (remain_dim ))]
509506
510507
508+ def _idx_list_struct_equal (idx_list1 , idx_list2 ):
509+ """Check if two idx_lists have the same structure.
510+
511+ Positions (integers) are treated as equivalent regardless of value,
512+ since positions are relative to each Op's inputs.
513+ """
514+ if len (idx_list1 ) != len (idx_list2 ):
515+ return False
516+
517+ def normalize_entry (entry ):
518+ if isinstance (entry , int ) and not isinstance (entry , bool ):
519+ return "POS" # All positions are equivalent
520+ elif isinstance (entry , slice ):
521+ return (
522+ "POS"
523+ if isinstance (entry .start , int ) and not isinstance (entry .start , bool )
524+ else entry .start ,
525+ "POS"
526+ if isinstance (entry .stop , int ) and not isinstance (entry .stop , bool )
527+ else entry .stop ,
528+ "POS"
529+ if isinstance (entry .step , int ) and not isinstance (entry .step , bool )
530+ else entry .step ,
531+ )
532+ else :
533+ return entry
534+
535+ for e1 , e2 in zip (idx_list1 , idx_list2 ):
536+ if normalize_entry (e1 ) != normalize_entry (e2 ):
537+ return False
538+ return True
539+
540+
511541@register_specialize
512542@register_canonicalize
513543@node_rewriter ([Subtensor ])
@@ -523,9 +553,17 @@ def local_subtensor_inc_subtensor(fgraph, node):
523553 if not x .owner .op .set_instead_of_inc :
524554 return
525555
526- if x .owner .inputs [2 :] == node .inputs [1 :] and tuple (
527- x .owner .op .idx_list
528- ) == tuple (node .op .idx_list ):
556+ # Check structural equality of idx_lists and semantic equality of inputs
557+ inc_inputs = x .owner .inputs [2 :]
558+ sub_inputs = node .inputs [1 :]
559+
560+ if (
561+ len (inc_inputs ) == len (sub_inputs )
562+ and _idx_list_struct_equal (x .owner .op .idx_list , node .op .idx_list )
563+ and all (
564+ equal_computations ([a ], [b ]) for a , b in zip (inc_inputs , sub_inputs )
565+ )
566+ ):
529567 out = node .outputs [0 ]
530568 y = x .owner .inputs [1 ]
531569 # If the dtypes differ, cast y into x.dtype
0 commit comments