Skip to content

Commit be2fb6b

Browse files
committed
Make non-strict zip strict in tensor/random/utils
1 parent dfdaeab commit be2fb6b

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

pytensor/tensor/random/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def explicit_expand_dims(
141141

142142
batch_dims = [
143143
param.type.ndim - ndim_param
144-
for param, ndim_param in zip(params, ndim_params, strict=False)
144+
for param, ndim_param in zip(params, ndim_params, strict=True)
145145
]
146146

147147
if size_length is not None:

tests/tensor/random/test_op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,16 @@ def test_RandomVariable_basics(strict_test_value_flags):
7474
# `dtype` is respected
7575
rv = RandomVariable("normal", signature="(),()->()", dtype="int32")
7676
with config.change_flags(compute_test_value="off"):
77-
rv_out = rv()
77+
rv_out = rv(0, 0)
7878
assert rv_out.dtype == "int32"
79-
rv_out = rv(dtype="int64")
79+
rv_out = rv(0, 0, dtype="int64")
8080
assert rv_out.dtype == "int64"
8181

8282
with pytest.raises(
8383
ValueError,
8484
match="Cannot change the dtype of a normal RV from int32 to float32",
8585
):
86-
assert rv(dtype="float32").dtype == "float32"
86+
assert rv(0, 0, dtype="float32").dtype == "float32"
8787

8888

8989
def test_RandomVariable_bcast(strict_test_value_flags):

0 commit comments

Comments
 (0)