Skip to content

Commit 8719bb1

Browse files
committed
Only require input_ndim and not input_broadcastable in DimShuffle
1 parent 509cdb9 commit 8719bb1

File tree

3 files changed

+5
-9
lines changed

3 files changed

+5
-9
lines changed

tests/tensor/rewriting/test_math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3749,7 +3749,7 @@ def test_local_log_sum_exp_maximum():
37493749
check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)
37503750

37513751
# If a transpose is applied to the sum
3752-
transpose_op = DimShuffle(ndim=2, new_order=(1, 0))
3752+
transpose_op = DimShuffle(input_ndim=2, new_order=(1, 0))
37533753
check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)
37543754

37553755
# If the sum is performed with keepdims=True

tests/tensor/test_elemwise.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,11 @@ def test_infer_shape(self):
109109
((1,), ("x", "x")),
110110
]:
111111
i_shape = [entry if entry == 1 else None for entry in xsh]
112-
ib = [(entry == 1) for entry in xsh]
113112
adtens = self.type(self.dtype, shape=i_shape)("x")
114113
adtens_val = np.ones(xsh, dtype=self.dtype)
115114
self._compile_and_check(
116115
[adtens],
117-
[self.op(ib, shuffle)(adtens)],
116+
[self.op(input_ndim=len(xsh), new_order=shuffle)(adtens)],
118117
[adtens_val],
119118
self.op,
120119
warn=False,
@@ -188,7 +187,7 @@ def test_static_shape(self):
188187
def test_valid_input_ndim(self):
189188
assert DimShuffle(input_ndim=2, new_order=(1, 0)).input_ndim == 2
190189

191-
with pytest.raises(TypeError, match="input_ndim must an integer"):
190+
with pytest.raises(TypeError, match="input_ndim must be an integer"):
192191
DimShuffle(input_ndim=(True, False), new_order=(1, 0))
193192

194193

tests/tensor/test_extra_ops.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -481,12 +481,9 @@ def test_invalid_input(self):
481481
assert f([0]) == 0
482482

483483
# Test that we cannot squeeze dimensions whose length is greater than 1
484-
error_txt_1 = re.escape("SpecifyShape: Got shape (3,), expected (1,).")
485-
error_txt_2 = re.escape("SpecifyShape: dim 0 of input has shape 3, expected 1")
486-
match = error_txt_1 if pytensor.config.mode == "FAST_COMPILE" else error_txt_2
487484
with pytest.raises(
488-
AssertionError,
489-
match=match,
485+
ValueError,
486+
match="cannot reshape array of size 3 into shape ()",
490487
):
491488
f([0, 1, 2])
492489

0 commit comments

Comments
 (0)