We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c01d61d commit f9a7865Copy full SHA for f9a7865
pytensor/tensor/rewriting/blockwise.py
@@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node):
127
value, *shape = inp.owner.inputs
128
129
# Check what to do with the value of the Alloc
130
- squeezed_value = _squeeze_left(value, batch_ndim)
131
- missing_ndim = len(shape) - value.type.ndim
+ missing_ndim = inp.type.ndim - value.type.ndim
+ squeezed_value = _squeeze_left(value, (batch_ndim - missing_ndim))
132
if (
133
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
134
!= inp.type.broadcastable[batch_ndim:]
0 commit comments