Skip to content

Commit 5bed8a5

Browse files
committed
Cleanup Split methods
1 parent bcac445 commit 5bed8a5

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

pytensor/tensor/basic.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2224,21 +2224,21 @@ def make_node(self, x, axis, splits):
22242224

22252225
return Apply(self, inputs, outputs)
22262226

2227-
def perform(self, node, inputs, outputs):
2227+
def perform(self, node, inputs, outputs_storage):
22282228
x, axis, splits = inputs
22292229

22302230
if len(splits) != self.len_splits:
22312231
raise ValueError("Length of splits is not equal to n_splits")
2232-
if np.sum(splits) != x.shape[axis]:
2232+
if splits.sum() != x.shape[axis]:
22332233
raise ValueError(
2234-
f"Split sizes sum to {np.sum(splits)}; expected {x.shape[axis]}"
2234+
f"Split sizes sum to {splits.sum()}; expected {x.shape[axis]}"
22352235
)
2236-
if np.any(splits < 0):
2236+
if (splits < 0).any():
22372237
raise ValueError("Split sizes cannot be negative")
22382238

22392239
split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
2240-
for i, out in enumerate(split_outs):
2241-
outputs[i][0] = out
2240+
for output_storage, out in enumerate(outputs_storage, split_outs):
2241+
outputs_storage[0] = out
22422242

22432243
def infer_shape(self, fgraph, node, in_shapes):
22442244
axis = node.inputs[1]
@@ -2252,10 +2252,10 @@ def infer_shape(self, fgraph, node, in_shapes):
22522252
out_shapes.append(temp)
22532253
return out_shapes
22542254

2255-
def grad(self, inputs, g_outputs):
2255+
def L_op(self, inputs, outputs, g_outputs):
22562256
"""Join the gradients along the axis that was used to split x."""
22572257
x, axis, n = inputs
2258-
outputs = self(*inputs, return_list=True)
2258+
22592259
# If all the output gradients are disconnected, then so are the inputs
22602260
if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs):
22612261
return [

0 commit comments

Comments
 (0)