diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 937741c4cd..d0b6b5fe0a 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -123,6 +123,18 @@ def parse_bcast_and_shape(s): self.name = name self.numpy_dtype = np.dtype(self.dtype) + def __call__(self, *args, shape=None, **kwargs): + if shape is not None: + # Check if shape is compatible with the original type + new_type = self.clone(shape=shape) + if self.is_super(new_type): + return new_type(*args, **kwargs) + else: + raise ValueError( + f"{shape=} is incompatible with original type shape {self.shape=}" + ) + return super().__call__(*args, **kwargs) + def clone( self, dtype=None, shape=None, broadcastable=None, **kwargs ) -> "TensorType": diff --git a/tests/tensor/test_type.py b/tests/tensor/test_type.py index 6a0ae4f957..e9a1914067 100644 --- a/tests/tensor/test_type.py +++ b/tests/tensor/test_type.py @@ -10,6 +10,10 @@ from pytensor.tensor.type import ( TensorType, col, + dmatrix, + drow, + fmatrix, + frow, matrix, row, scalar, @@ -477,3 +481,21 @@ def test_row_matrix_creator_helpers(helper): match = "The second dimension of a `col` must have shape 1, got 5" with pytest.raises(ValueError, match=match): helper(shape=(2, 5)) + + +def test_shape_of_predefined_dtype_tensor(): + # Valid: None dimensions can be specialized + assert fmatrix(shape=(1, None)).type == frow + assert drow(shape=(1, 5)).type == dmatrix(shape=(1, 5)).type + + # Invalid: Number of dimensions must match + with pytest.raises(ValueError): + fmatrix(shape=(None, None, None)) + + # Invalid: Fixed shapes must match + with pytest.raises(ValueError): + fmatrix(shape=(3, 5)).type(shape=(4, 5)) + + # Invalid: Known shapes can't be lost + with pytest.raises(ValueError): + drow(shape=(None, None))