Skip to content

Commit fe5865e

Browse files
committed
Remove assert in local_useless_alloc
Rewrite was already tagged as "shape_unsafe"
1 parent c855a6d commit fe5865e

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,7 @@
6767
from pytensor.tensor.elemwise import DimShuffle, Elemwise
6868
from pytensor.tensor.exceptions import NotScalarConstantError
6969
from pytensor.tensor.extra_ops import broadcast_arrays
70-
from pytensor.tensor.math import Sum, add
71-
from pytensor.tensor.math import all as at_all
72-
from pytensor.tensor.math import eq
70+
from pytensor.tensor.math import Sum, add, eq
7371
from pytensor.tensor.shape import Shape_i, shape_padleft
7472
from pytensor.tensor.sort import TopKOp
7573
from pytensor.tensor.type import DenseTensorType, TensorType
@@ -266,6 +264,7 @@ def local_elemwise_alloc(fgraph, node):
266264
introduces them as a canonicalization of `Alloc`'s with leading
267265
broadcastable dimensions.
268266
"""
267+
# This is handled by local_alloc_unary
269268
if len(node.inputs) == 1:
270269
return None
271270

@@ -465,14 +464,7 @@ def local_useless_alloc(fgraph, node):
465464
inp.type.dtype == output.type.dtype
466465
and inp.type.broadcastable == output.type.broadcastable
467466
):
468-
if inp.ndim == 0:
469-
return [inp]
470-
else:
471-
return [
472-
Assert("Shapes must be equal")(
473-
inp, at_all(eq(inp.shape, node.inputs[1:]))
474-
)
475-
]
467+
return [inp]
476468

477469

478470
@register_specialize

tests/tensor/rewriting/test_basic.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -272,21 +272,36 @@ class TestLocalCanonicalizeAlloc:
272272
def setup_method(self):
273273
self.rng = np.random.default_rng(utt.fetch_seed())
274274

275-
def test_inconsistent_shared(self):
275+
@pytest.mark.parametrize("shape_unsafe", (True, False))
276+
def test_inconsistent_shared(self, shape_unsafe):
276277
# These shapes don't match!
277278
x = shared(self.rng.standard_normal((3, 7)))
278279
a = at.alloc(x, 6, 7)
279280

280281
assert a.owner and isinstance(a.owner.op, Alloc)
281282

282-
f = function([], a, mode=rewrite_mode)
283+
mode = rewrite_mode if shape_unsafe else rewrite_mode.excluding("shape_unsafe")
284+
f = function([], a, mode=mode)
283285

284-
# The rewrite should then be applied, and remove Alloc
285-
assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort())
286-
assert any(isinstance(node.op, Assert) for node in f.maker.fgraph.toposort())
287-
288-
with pytest.raises(AssertionError):
289-
f()
286+
has_alloc = any(
287+
isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()
288+
)
289+
if shape_unsafe:
290+
assert not has_alloc
291+
# Error raised by SpecifyShape that is introduced due to static shape inference
292+
with pytest.raises(
293+
AssertionError,
294+
match="SpecifyShape: dim 0 of input has shape 3, expected 6.",
295+
):
296+
f()
297+
else:
298+
assert has_alloc
299+
# Error raised by Alloc Op
300+
with pytest.raises(
301+
ValueError,
302+
match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)",
303+
):
304+
f()
290305

291306
good_x_val = self.rng.standard_normal((6, 7))
292307
x.set_value(good_x_val)

0 commit comments

Comments
 (0)