Skip to content

Commit f9a7865

Browse files
committed
Fix local_blockwise_alloc rewrite
The rewrite was squeezing too many dimensions of the alloced value, when this didn't have dummy expand dims to the left.
1 parent c01d61d commit f9a7865

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytensor/tensor/rewriting/blockwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node):
127127
value, *shape = inp.owner.inputs
128128

129129
# Check what to do with the value of the Alloc
130-
squeezed_value = _squeeze_left(value, batch_ndim)
131-
missing_ndim = len(shape) - value.type.ndim
130+
missing_ndim = inp.type.ndim - value.type.ndim
131+
squeezed_value = _squeeze_left(value, (batch_ndim - missing_ndim))
132132
if (
133133
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
134134
!= inp.type.broadcastable[batch_ndim:]

0 commit comments

Comments
 (0)