
Commit 712660e

Make reshape ndim kwarg only
This prevents surprises when passing two scalars, which the NumPy API interprets differently: NumPy reads them as two shape dimensions, whereas the old signature read the second one as ndim.
1 parent f27ac45 commit 712660e
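A minimal sketch of the surprise this avoids (illustrative only; the variable names below are not from the diff). Under the old signature, a second positional scalar was taken as ndim; with ndim keyword-only, that call now fails loudly instead:

import pytensor.tensor as pt

x = pt.matrix("x")

# NumPy-style call with two scalars: under the old reshape(shape, ndim=None)
# signature the 3 would be read as ndim, not as a dimension of the target
# shape. With ndim keyword-only this now raises a TypeError instead.
# y = x.reshape(2, 3)

# Unambiguous forms that keep working:
y_tuple = x.reshape((2, 3))          # shape given as a tuple, as in NumPy
y_kw = x.reshape(x.shape, ndim=2)    # ndim passed explicitly by keyword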

5 files changed, +8 −5 lines changed


pytensor/tensor/extra_ops.py

Lines changed: 4 additions & 1 deletion
@@ -708,7 +708,10 @@ def grad(self, inputs, gout):
             shape = [x.shape[k] for k in range(x.ndim)]
             shape.insert(axis, repeats)
 
-            return [gz.reshape(shape, x.ndim + 1).sum(axis=axis), DisconnectedType()()]
+            return [
+                gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis),
+                DisconnectedType()(),
+            ]
         elif repeats.ndim == 1:
             # For this implementation, we would need to specify the length
             # of repeats in order to split gz in the right way to sum

pytensor/tensor/math.py

Lines changed: 1 addition & 1 deletion
@@ -2196,7 +2196,7 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
     b_reshaped = b.reshape(b_shape)
 
     out_reshaped = dot(a_reshaped, b_reshaped)
-    out = out_reshaped.reshape(outshape, outndim)
+    out = out_reshaped.reshape(outshape, ndim=outndim)
     # Make sure the broadcastable pattern of the result is correct,
     # since some shape information can be lost in the reshapes.
     if out.type.broadcastable != outbcast:

pytensor/tensor/rewriting/subtensor.py

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ def transform_take(a, indices, axis):
 
     ndim = a.ndim + indices.ndim - 1
 
-    return transform_take(a, indices.flatten(), axis).reshape(shape, ndim)
+    return transform_take(a, indices.flatten(), axis).reshape(shape, ndim=ndim)
 
 
 def is_full_slice(x):

pytensor/tensor/slinalg.py

Lines changed: 1 addition & 1 deletion
@@ -580,7 +580,7 @@ def kron(a, b):
             f"You passed {int(a.ndim)} and {int(b.ndim)}."
         )
     o = atm.outer(a, b)
-    o = o.reshape(at.concatenate((a.shape, b.shape)), a.ndim + b.ndim)
+    o = o.reshape(at.concatenate((a.shape, b.shape)), ndim=a.ndim + b.ndim)
     shf = o.dimshuffle(0, 2, 1, *list(range(3, o.ndim)))
     if shf.ndim == 3:
         shf = o.dimshuffle(1, 0, 2)

pytensor/tensor/var.py

Lines changed: 1 addition & 1 deletion
@@ -283,7 +283,7 @@ def all(self, axis=None, keepdims=False):
     # "Variable) due to Python restriction. You can use "
     # "PyTensorVariable.shape[0] instead.")
 
-    def reshape(self, shape, ndim=None):
+    def reshape(self, shape, *, ndim=None):
         """Return a reshaped view/copy of this variable.
 
         Parameters
0 commit comments
