55import numpy as np
66
77import pytensor
8- import pytensor .scalar .basic as ps
98from pytensor import compile
109from pytensor .compile import optdb
1110from pytensor .graph .basic import Constant , Variable
1413 copy_stack_trace ,
1514 in2out ,
1615 node_rewriter ,
16+ out2in ,
1717)
1818from pytensor .raise_op import Assert
19+ from pytensor .scalar import Add , ScalarConstant , ScalarType
20+ from pytensor .scalar import constant as scalar_constant
1921from pytensor .tensor .basic import (
2022 Alloc ,
2123 Join ,
3133 register_infer_shape ,
3234 switch ,
3335)
36+ from pytensor .tensor .basic import constant as tensor_constant
3437from pytensor .tensor .blockwise import Blockwise
3538from pytensor .tensor .elemwise import Elemwise
3639from pytensor .tensor .exceptions import NotScalarConstantError
@@ -588,11 +591,11 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
588591 remove_dim = []
589592 node_inputs_idx = 1
590593 for dim , elem in enumerate (idx ):
591- if isinstance (elem , ( ps . ScalarType ) ):
594+ if isinstance (elem , ScalarType ):
592595 # The idx is a ScalarType, ie a Type. This means the actual index
593596 # is contained in node.inputs[1]
594597 dim_index = node .inputs [node_inputs_idx ]
595- if isinstance (dim_index , ps . ScalarConstant ):
598+ if isinstance (dim_index , ScalarConstant ):
596599 dim_index = dim_index .value
597600 if dim_index in (0 , - 1 ) and node .inputs [0 ].broadcastable [dim ]:
598601 remove_dim .append (dim )
@@ -770,7 +773,7 @@ def local_subtensor_make_vector(fgraph, node):
770773
771774 (idx ,) = idxs
772775
773- if isinstance (idx , ps . ScalarType | TensorType ):
776+ if isinstance (idx , ScalarType | TensorType ):
774777 old_idx , idx = idx , node .inputs [1 ]
775778 assert idx .type .is_super (old_idx )
776779 elif isinstance (node .op , AdvancedSubtensor1 ):
@@ -895,7 +898,7 @@ def local_set_to_inc_subtensor(fgraph, node):
895898 and node .op .set_instead_of_inc
896899 and node .inputs [1 ].owner
897900 and isinstance (node .inputs [1 ].owner .op , Elemwise )
898- and isinstance (node .inputs [1 ].owner .op .scalar_op , ps . Add )
901+ and isinstance (node .inputs [1 ].owner .op .scalar_op , Add )
899902 ):
900903 addn = node .inputs [1 ].owner
901904 subn = None
@@ -1789,7 +1792,6 @@ def local_join_subtensors(fgraph, node):
17891792 return [merged_subtensors ]
17901793
17911794
1792- @register_specialize
17931795@node_rewriter (
17941796 [
17951797 Subtensor ,
@@ -1850,12 +1852,10 @@ def local_uint_constant_indices(fgraph, node):
18501852 if dtype == index_val .dtype :
18511853 continue
18521854
1853- if index_val .ndim > 0 :
1854- new_index = pytensor .tensor .as_tensor_variable (
1855- index_val .astype (dtype ), dtype = dtype
1856- )
1855+ if isinstance (index .type , TensorType ):
1856+ new_index = tensor_constant (index_val .astype (dtype ), dtype = dtype )
18571857 else :
1858- new_index = ps . constant (index_val .astype (dtype ), dtype = dtype )
1858+ new_index = scalar_constant (index_val .astype (dtype ), dtype = dtype )
18591859
18601860 new_indices [i ] = new_index
18611861 has_new_index = True
@@ -1877,6 +1877,20 @@ def local_uint_constant_indices(fgraph, node):
18771877 return [new_out ]
18781878
18791879
# Register the uint-index rewrite straight into the optimization database
# (out2in traversal) instead of through the @register_specialize decorator,
# so its placement in the pipeline can be controlled explicitly.
compile.optdb.register(
    local_uint_constant_indices.__name__,
    out2in(local_uint_constant_indices),
    # Only useful on these backends: the Python / C backends always cast
    # indices to int64 internally anyway.
    "numba",
    "jax",
    # Run after specialization and uncanonicalization. Other rewrites do not
    # care about the dtype of indices and, if this ran earlier, they could
    # trigger needless re-passes of this rewrite — e.g. turning
    # x.shape[np.int(0)] into x.shape[np.uint(0)] back and forth.
    position=4,
)
1893+
18801894@register_canonicalize ("shape_unsafe" )
18811895@register_stabilize ("shape_unsafe" )
18821896@register_specialize ("shape_unsafe" )
0 commit comments