diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index ffa27e5d5a..f80dfaaf5c 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -489,7 +489,6 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
             "BlasOpt",
             "fusion",
             "inplace",
-            "local_uint_constant_indices",
             "scan_save_mem_prealloc",
         ],
     ),
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 1af10e52b4..defb72bfbc 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -5,7 +5,6 @@
 import numpy as np
 
 import pytensor
-import pytensor.scalar.basic as ps
 from pytensor import compile
 from pytensor.compile import optdb
 from pytensor.graph.basic import Constant, Variable
@@ -14,8 +13,11 @@
     copy_stack_trace,
     in2out,
     node_rewriter,
+    out2in,
 )
 from pytensor.raise_op import Assert
+from pytensor.scalar import Add, ScalarConstant, ScalarType
+from pytensor.scalar import constant as scalar_constant
 from pytensor.tensor.basic import (
     Alloc,
     Join,
@@ -31,6 +33,7 @@
     register_infer_shape,
     switch,
 )
+from pytensor.tensor.basic import constant as tensor_constant
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
@@ -588,11 +591,11 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
     remove_dim = []
     node_inputs_idx = 1
     for dim, elem in enumerate(idx):
-        if isinstance(elem, (ps.ScalarType)):
+        if isinstance(elem, ScalarType):
             # The idx is a ScalarType, ie a Type. This means the actual index
             # is contained in node.inputs[1]
             dim_index = node.inputs[node_inputs_idx]
-            if isinstance(dim_index, ps.ScalarConstant):
+            if isinstance(dim_index, ScalarConstant):
                 dim_index = dim_index.value
             if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]:
                 remove_dim.append(dim)
@@ -770,7 +773,7 @@ def local_subtensor_make_vector(fgraph, node):
 
     (idx,) = idxs
 
-    if isinstance(idx, ps.ScalarType | TensorType):
+    if isinstance(idx, ScalarType | TensorType):
         old_idx, idx = idx, node.inputs[1]
         assert idx.type.is_super(old_idx)
     elif isinstance(node.op, AdvancedSubtensor1):
@@ -895,7 +898,7 @@ def local_set_to_inc_subtensor(fgraph, node):
         and node.op.set_instead_of_inc
         and node.inputs[1].owner
         and isinstance(node.inputs[1].owner.op, Elemwise)
-        and isinstance(node.inputs[1].owner.op.scalar_op, ps.Add)
+        and isinstance(node.inputs[1].owner.op.scalar_op, Add)
     ):
         addn = node.inputs[1].owner
         subn = None
@@ -1789,7 +1792,6 @@ def local_join_subtensors(fgraph, node):
     return [merged_subtensors]
 
 
-@register_specialize
 @node_rewriter(
     [
         Subtensor,
@@ -1850,12 +1852,10 @@ def local_uint_constant_indices(fgraph, node):
         if dtype == index_val.dtype:
            continue
 
-        if index_val.ndim > 0:
-            new_index = pytensor.tensor.as_tensor_variable(
-                index_val.astype(dtype), dtype=dtype
-            )
+        if isinstance(index.type, TensorType):
+            new_index = tensor_constant(index_val.astype(dtype), dtype=dtype)
         else:
-            new_index = ps.constant(index_val.astype(dtype), dtype=dtype)
+            new_index = scalar_constant(index_val.astype(dtype), dtype=dtype)
 
         new_indices[i] = new_index
         has_new_index = True
@@ -1877,6 +1877,20 @@ def local_uint_constant_indices(fgraph, node):
     return [new_out]
 
 
+compile.optdb.register(
+    local_uint_constant_indices.__name__,
+    out2in(local_uint_constant_indices),
+    # We don't include it in the Python / C backends because those always cast indices to int64 internally.
+ "numba", + "jax", + # After specialization and uncanonicalization + # Other rewrites don't worry about the dtype of the indices + # And can cause unnecessary passes of this optimization + # Such as x.shape[np.int(0)] -> x.shape[np.uint(0)] + position=4, +) + + @register_canonicalize("shape_unsafe") @register_stabilize("shape_unsafe") @register_specialize("shape_unsafe") diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 8e3e5cb902..9c14f31e1d 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -3,7 +3,6 @@ import warnings from collections.abc import Callable, Iterable, Sequence from itertools import chain, groupby -from textwrap import dedent from typing import cast, overload import numpy as np @@ -19,7 +18,7 @@ from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType -from pytensor.npy_2_compat import npy_2_compat_header, numpy_version, using_numpy_2 +from pytensor.npy_2_compat import numpy_version, using_numpy_2 from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.tensor import ( @@ -2130,24 +2129,6 @@ def perform(self, node, inp, out_): else: o = None - # If i.dtype is more precise than numpy.intp (int32 on 32-bit machines, - # int64 on 64-bit machines), numpy may raise the following error: - # TypeError: array cannot be safely cast to required type. - # We need to check if values in i can fit in numpy.intp, because - # if they don't, that should be an error (no array can have that - # many elements on a 32-bit arch). - if i.dtype != np.intp: - i_ = np.asarray(i, dtype=np.intp) - if not np.can_cast(i.dtype, np.intp): - # Check if there was actually an incorrect conversion - if np.any(i != i_): - raise IndexError( - "index contains values that are bigger " - "than the maximum array size on this system.", - i, - ) - i = i_ - out[0] = x.take(i, axis=0, out=o) def connection_pattern(self, node): @@ -2187,16 +2168,6 @@ def infer_shape(self, fgraph, node, ishapes): x, ilist = ishapes return [ilist + x[1:]] - def c_support_code(self, **kwargs): - # In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG, - # which is not defined. It should be NPY_MIN_LONG instead in that case. - return npy_2_compat_header() + dedent( - """\ - #ifndef MIN_LONG - #define MIN_LONG NPY_MIN_LONG - #endif""" - ) - def c_code(self, node, name, input_names, output_names, sub): if self.__class__ is not AdvancedSubtensor1: raise MethodNotDefined( @@ -2207,61 +2178,16 @@ def c_code(self, node, name, input_names, output_names, sub): output_name = output_names[0] fail = sub["fail"] return f""" - PyArrayObject *indices; - int i_type = PyArray_TYPE({i_name}); - if (i_type != NPY_INTP) {{ - // Cast {i_name} to NPY_INTP (expected by PyArray_TakeFrom), - // if all values fit. 
-                if (!PyArray_CanCastSafely(i_type, NPY_INTP) &&
-                    PyArray_SIZE({i_name}) > 0) {{
-                    npy_int64 min_val, max_val;
-                    PyObject* py_min_val = PyArray_Min({i_name}, NPY_RAVEL_AXIS,
-                                                       NULL);
-                    if (py_min_val == NULL) {{
-                        {fail};
-                    }}
-                    min_val = PyLong_AsLongLong(py_min_val);
-                    Py_DECREF(py_min_val);
-                    if (min_val == -1 && PyErr_Occurred()) {{
-                        {fail};
-                    }}
-                    PyObject* py_max_val = PyArray_Max({i_name}, NPY_RAVEL_AXIS,
-                                                       NULL);
-                    if (py_max_val == NULL) {{
-                        {fail};
-                    }}
-                    max_val = PyLong_AsLongLong(py_max_val);
-                    Py_DECREF(py_max_val);
-                    if (max_val == -1 && PyErr_Occurred()) {{
-                        {fail};
-                    }}
-                    if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {{
-                        PyErr_SetString(PyExc_IndexError,
-                                     "Index contains values "
-                                     "that are bigger than the maximum array "
-                                     "size on this system.");
-                        {fail};
-                    }}
-                }}
-                indices = (PyArrayObject*) PyArray_Cast({i_name}, NPY_INTP);
-                if (indices == NULL) {{
-                    {fail};
-                }}
-            }}
-            else {{
-                indices = {i_name};
-                Py_INCREF(indices);
-            }}
             if ({output_name} != NULL) {{
                 npy_intp nd, i, *shape;
-                nd = PyArray_NDIM({a_name}) + PyArray_NDIM(indices) - 1;
+                nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
                 if (PyArray_NDIM({output_name}) != nd) {{
                     Py_CLEAR({output_name});
                 }}
                 else {{
                     shape = PyArray_DIMS({output_name});
-                    for (i = 0; i < PyArray_NDIM(indices); i++) {{
-                        if (shape[i] != PyArray_DIMS(indices)[i]) {{
+                    for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
+                        if (shape[i] != PyArray_DIMS({i_name})[i]) {{
                             Py_CLEAR({output_name});
                             break;
                         }}
@@ -2269,7 +2195,7 @@
                     if ({output_name} != NULL) {{
                         for (; i < nd; i++) {{
                             if (shape[i] != PyArray_DIMS({a_name})[
-                                    i-PyArray_NDIM(indices)+1]) {{
+                                    i-PyArray_NDIM({i_name})+1]) {{
                                 Py_CLEAR({output_name});
                                 break;
                             }}
@@ -2278,13 +2204,12 @@
                 }}
             }}
             {output_name} = (PyArrayObject*)PyArray_TakeFrom(
-                        {a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE);
-            Py_DECREF(indices);
+                        {a_name}, (PyObject*){i_name}, 0, {output_name}, NPY_RAISE);
             if ({output_name} == NULL) {fail};
         """
 
     def c_code_cache_version(self):
-        return (0, 1, 2, 3)
+        return (4,)
 
 
 advanced_subtensor1 = AdvancedSubtensor1()
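
Not part of the diff above: a minimal sketch of how the relocated rewrite could be exercised, assuming a working Numba backend. local_uint_constant_indices is now registered only under the "numba" and "jax" tags, since the Python and C linkers cast indices to int64 internally, so the constant index below should stay signed in the default mode while the Numba-compiled graph may carry an unsigned constant instead. The variable names and the printed expectations are illustrative, not taken from the PR.

# Sketch only (not from the PR): compare the index constant per backend.
# Assumes Numba is installed; the exact graph layout may differ between versions.
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
out = x[7]  # constant index; built as a signed integer constant by default

f_default = pytensor.function([x], out)               # C/Python backend: rewrite not registered
f_numba = pytensor.function([x], out, mode="NUMBA")   # rewrite tagged for "numba"

pytensor.dprint(f_default)  # index constant expected to remain signed
pytensor.dprint(f_numba)    # index constant may show an unsigned dtype here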