Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pytensor/tensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ def parse_bcast_and_shape(s):
self.name = name
self.numpy_dtype = np.dtype(self.dtype)

def __call__(self, *args, shape=None, **kwargs):
if shape is not None:
# Check if shape is compatible with the original type
new_type = self.clone(shape=shape)
if self.is_super(new_type):
return new_type(*args, **kwargs)
else:
raise ValueError(
f"{shape=} is incompatible with original type shape {self.shape=}"
)
return super().__call__(*args, **kwargs)

def clone(
self, dtype=None, shape=None, broadcastable=None, **kwargs
) -> "TensorType":
Expand Down
22 changes: 22 additions & 0 deletions tests/tensor/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from pytensor.tensor.type import (
TensorType,
col,
dmatrix,
drow,
fmatrix,
frow,
matrix,
row,
scalar,
Expand Down Expand Up @@ -477,3 +481,21 @@ def test_row_matrix_creator_helpers(helper):
match = "The second dimension of a `col` must have shape 1, got 5"
with pytest.raises(ValueError, match=match):
helper(shape=(2, 5))


def test_shape_of_predefined_dtype_tensor():
# Valid: None dimensions can be specialized
assert fmatrix(shape=(1, None)).type == frow
assert drow(shape=(1, 5)).type == dmatrix(shape=(1, 5)).type

# Invalid: Number of dimensions must match
with pytest.raises(ValueError):
fmatrix(shape=(None, None, None))

# Invalid: Fixed shapes must match
with pytest.raises(ValueError):
fmatrix(shape=(3, 5)).type(shape=(4, 5))

# Invalid: Known shapes can't be lost
with pytest.raises(ValueError):
drow(shape=(None, None))