diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 6891823576..68f7b8c54e 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -385,7 +385,9 @@ def make_node(self, rng, size, *dist_params): dist_params = explicit_expand_dims( dist_params, self.ndims_params, - size_length=None if NoneConst.equals(size) else get_vector_length(size), + size_length=None + if isinstance(size.type, NoneTypeT) + else get_vector_length(size), ) inputs = (rng, size, *dist_params) diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index d67a6653f4..8b2dd3d0a1 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -9,7 +9,7 @@ dfs_rewriter, node_rewriter, ) -from pytensor.tensor import NoneConst, TensorVariable +from pytensor.tensor import TensorVariable from pytensor.tensor.basic import constant from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.extra_ops import broadcast_to @@ -20,7 +20,7 @@ AdvancedSubtensor, AdvancedSubtensor1, Subtensor, - get_idx_list, + indices_from_subtensor, ) from pytensor.tensor.type import integer_dtypes from pytensor.tensor.type_other import NoneTypeT, SliceType @@ -237,17 +237,20 @@ def is_nd_advanced_idx(idx, dtype) -> bool: return False # Parse indices - indices = get_idx_list(node.inputs, getattr(subtensor_op, "idx_list", None)) - - # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates) - # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis). - # If we wanted to support that we could rewrite it as subtensor + dimshuffle - # and make use of the dimshuffle lift rewrite - if any( - is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx) - for idx in indices - ): - return False + if isinstance(subtensor_op, Subtensor): + indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list) + else: + indices = node.inputs[1:] + # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates) + # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis). 
+    # If we wanted to support that we could rewrite it as subtensor + dimshuffle
+    # and make use of the dimshuffle lift rewrite
+    # TODO: This rewrite aborts on dummy indexing dimensions (np.newaxis), even though they are not a problem
+    if any(
+        is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT)
+        for idx in indices
+    ):
+        return False

     # Check that indexing does not act on support dims
     batch_ndims = rv_op.batch_ndim(rv_node)
@@ -267,7 +270,7 @@ def is_nd_advanced_idx(idx, dtype) -> bool:
         for idx in supp_indices:
             if not (
                 isinstance(idx.type, SliceType)
-                and all(NoneConst.equals(i) for i in idx.owner.inputs)
+                and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
             ):
                 return False
         n_discarded_idxs = len(supp_indices)
diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py
index 86628a81cb..6f2700f55f 100644
--- a/pytensor/tensor/random/utils.py
+++ b/pytensor/tensor/random/utils.py
@@ -7,7 +7,7 @@
 import numpy as np

 from pytensor.compile.sharedvalue import shared
-from pytensor.graph.basic import Constant, Variable
+from pytensor.graph.basic import Variable
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import NoneConst, get_vector_length
 from pytensor.tensor.basic import as_tensor_variable, cast
@@ -15,6 +15,7 @@
 from pytensor.tensor.math import maximum
 from pytensor.tensor.shape import shape_padleft, specify_shape
 from pytensor.tensor.type import int_dtypes
+from pytensor.tensor.type_other import NoneTypeT
 from pytensor.tensor.utils import faster_broadcast_to
 from pytensor.tensor.variable import TensorVariable

@@ -178,24 +179,26 @@ def normalize_size_param(
     shape: int | np.ndarray | Variable | Sequence | None,
 ) -> Variable:
     """Create an PyTensor value for a ``RandomVariable`` ``size`` parameter."""
-    if shape is None or NoneConst.equals(shape):
+    if shape is None:
         return NoneConst
-    elif isinstance(shape, int):
+    if isinstance(shape, Variable) and isinstance(shape.type, NoneTypeT):
+        return shape
+
+    if isinstance(shape, int):
         shape = as_tensor_variable([shape], ndim=1)
-    elif not isinstance(shape, np.ndarray | Variable | Sequence):
-        raise TypeError(
-            "Parameter size must be None, an integer, or a sequence with integers."
-        )
     else:
+        if not isinstance(shape, Sequence | Variable | np.ndarray):
+            raise TypeError(
+                "Parameter size must be None, an integer, or a sequence with integers."
+            )
         shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64")

-    if not isinstance(shape, Constant):
+    if shape.type.shape == (None,):
         # This should help ensure that the length of non-constant `size`s
-        # will be available after certain types of cloning (e.g. the kind
-        # `Scan` performs)
+        # will be available after certain types of cloning (e.g. the kind `Scan` performs)
         shape = specify_shape(shape, (get_vector_length(shape),))

-    assert not any(s is None for s in shape.type.shape)
+    assert shape.type.shape != (None,)
     assert shape.dtype in int_dtypes

     return shape
diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py
index 5dd2859147..af953c79fd 100644
--- a/pytensor/tensor/rewriting/shape.py
+++ b/pytensor/tensor/rewriting/shape.py
@@ -47,7 +47,7 @@
 )
 from pytensor.tensor.subtensor import Subtensor, get_idx_list
 from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
-from pytensor.tensor.type_other import NoneConst, NoneTypeT
+from pytensor.tensor.type_other import NoneTypeT
 from pytensor.tensor.variable import TensorVariable


@@ -1137,7 +1137,7 @@ def local_merge_consecutive_specify_shape(fgraph, node):

     inner_obj, *shape = obj.owner.inputs
     for dim, sh in enumerate(node.inputs[1:]):
-        if not NoneConst.equals(sh):
+        if not isinstance(sh.type, NoneTypeT):
             shape[dim] = sh

     # TODO: We could make sure that the overlapping shapes of the two `SpecifyShape`s are
@@ -1183,7 +1183,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):

     # Replace `NoneConst` by `shape_i`
     for i, sh in enumerate(shape):
-        if NoneConst.equals(sh):
+        if isinstance(sh.type, NoneTypeT):
             shape[i] = x.shape[i]

     return [stack(shape).astype(np.int64)]
@@ -1219,7 +1219,7 @@ def local_specify_shape_lift(fgraph, node):
             for i, (dim, bcast) in enumerate(
                 zip(shape, out_broadcastable, strict=True)
             )
-            if (not bcast and not NoneConst.equals(dim))
+            if (not bcast and not isinstance(dim.type, NoneTypeT))
         }
         new_elem_inps = elem_inps.copy()
         for i, elem_inp in enumerate(elem_inps):
diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py
index 78ab8864c9..bdac2a6c2b 100644
--- a/pytensor/tensor/shape.py
+++ b/pytensor/tensor/shape.py
@@ -408,7 +408,9 @@ def make_node(self, x, *shape):

         shape = tuple(
             NoneConst
-            if (s is None or NoneConst.equals(s))
+            if (
+                s is None or (isinstance(s, Variable) and isinstance(s.type, NoneTypeT))
+            )
             else ptb.as_tensor_variable(s, ndim=0)
             for s in shape
         )
@@ -506,7 +508,7 @@ def c_code(self, node, name, i_names, o_names, sub):
         for i, (shp_name, shp) in enumerate(
             zip(shape_names, node.inputs[1:], strict=True)
         ):
-            if NoneConst.equals(shp):
+            if isinstance(shp.type, NoneTypeT):
                 continue
             code += dedent(
                 f"""
@@ -594,7 +596,10 @@ def _vectorize_specify_shape(op, node, x, *shape):
     if any(
         as_tensor_variable(dim).type.ndim != 0
         for dim in shape
-        if not (NoneConst.equals(dim) or dim is None)
+        if not (
+            (isinstance(dim, Variable) and isinstance(dim.type, NoneTypeT))
+            or dim is None
+        )
     ):
         raise NotImplementedError(
             "It is not possible to vectorize the shape argument of SpecifyShape"
diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py
index adf42e0550..4fd4705390 100644
--- a/tests/tensor/random/test_op.py
+++ b/tests/tensor/random/test_op.py
@@ -11,6 +11,7 @@
 from pytensor.tensor.random.op import RandomVariable, default_rng
 from pytensor.tensor.shape import specify_shape
 from pytensor.tensor.type import iscalar, tensor
+from pytensor.tensor.type_other import none_type_t


 @pytest.fixture(scope="function", autouse=False)
@@ -317,3 +318,12 @@ def test_size_none_vs_empty():
         ValueError, match="Size length is incompatible with batched dimensions"
     ):
         rv([0], [1], size=())
+
+
+def test_non_constant_none_size():
+    # Regression test for https://github.com/pymc-devs/pymc/issues/7901#issuecomment-3528479876
+    loc = pt.vector("loc", dtype="float64")
+    size = none_type_t("none_size")
+
+    rv = normal(loc, size=size)
+    rv.eval({loc: np.arange(5, dtype="float64"), size: None}, mode="FAST_COMPILE")
diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py
index aa761d2922..9ad0bf2dc9 100644
--- a/tests/tensor/random/test_utils.py
+++ b/tests/tensor/random/test_utils.py
@@ -7,9 +7,11 @@
 from pytensor.tensor.random.utils import (
     RandomStream,
     broadcast_params,
+    normalize_size_param,
     supp_shape_from_ref_param_shape,
 )
-from pytensor.tensor.type import matrix, tensor
+from pytensor.tensor.type import TensorType, matrix, tensor
+from pytensor.tensor.type_other import NoneTypeT, none_type_t
 from tests import unittest_tools as utt


@@ -327,3 +329,22 @@ def test_supp_shape_from_ref_param_shape():
         ref_param_idx=1,
     )
     assert res == (3, 4)
+
+
+def test_normalize_size_param():
+    assert normalize_size_param(None).type == NoneTypeT()
+
+    sym_none_size = none_type_t()
+    assert normalize_size_param(sym_none_size) is sym_none_size
+
+    empty_size = normalize_size_param(())
+    assert empty_size.type == TensorType(dtype="int64", shape=(0,))
+
+    int_size = normalize_size_param(5)
+    assert int_size.type == TensorType(dtype="int64", shape=(1,))
+
+    seq_int_size = normalize_size_param((5, 3, 4))
+    assert seq_int_size.type == TensorType(dtype="int64", shape=(3,))
+
+    sym_tensor_size = tensor(shape=(3,), dtype="int64")
+    assert normalize_size_param(sym_tensor_size) is sym_tensor_size
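For context, the sketch below is not part of the patch; it illustrates the behaviour the new tests exercise. A *symbolic* `NoneTypeT` variable is not the `NoneConst` constant, so the previous `NoneConst.equals(...)` guards never matched it and a symbolic "None" size fell through to the integer/sequence handling, while the `isinstance(..., NoneTypeT)` checks introduced here accept it. Imports and names mirror the added tests; treat this as an illustration under the patched behaviour, not as additional test code.

```python
import numpy as np

import pytensor.tensor as pt
from pytensor.tensor.random.basic import normal
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.type_other import NoneConst, NoneTypeT, none_type_t

# A fresh NoneTypeT variable is not the NoneConst constant, so the old
# `NoneConst.equals(size)` guard returned False for it.
size = none_type_t("none_size")
assert isinstance(size.type, NoneTypeT)
assert not NoneConst.equals(size)

# With this patch, normalize_size_param passes such a variable through
# unchanged (see test_normalize_size_param).
assert normalize_size_param(size) is size

# And a RandomVariable built with a symbolic None size can be evaluated by
# feeding None for it (see test_non_constant_none_size).
loc = pt.vector("loc", dtype="float64")
rv = normal(loc, size=size)
print(rv.eval({loc: np.arange(5, dtype="float64"), size: None}, mode="FAST_COMPILE"))
```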