@@ -2224,21 +2224,21 @@ def make_node(self, x, axis, splits):
 
         return Apply(self, inputs, outputs)
 
-    def perform(self, node, inputs, outputs):
+    def perform(self, node, inputs, outputs_storage):
         x, axis, splits = inputs
 
         if len(splits) != self.len_splits:
             raise ValueError("Length of splits is not equal to n_splits")
-        if np.sum(splits) != x.shape[axis]:
+        if splits.sum() != x.shape[axis]:
             raise ValueError(
-                f"Split sizes sum to {np.sum(splits)}; expected {x.shape[axis]}"
+                f"Split sizes sum to {splits.sum()}; expected {x.shape[axis]}"
             )
-        if np.any(splits < 0):
+        if (splits < 0).any():
             raise ValueError("Split sizes cannot be negative")
 
         split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
-        for i, out in enumerate(split_outs):
-            outputs[i][0] = out
+        for output_storage, out in zip(outputs_storage, split_outs):
+            output_storage[0] = out
 
     def infer_shape(self, fgraph, node, in_shapes):
         axis = node.inputs[1]
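
The perform() rewrite above renames the third argument to outputs_storage to match what it actually is. A minimal sketch of that contract, assuming PyTensor's convention that outputs_storage is a list of one-element lists (one cell per output) whose slot 0 receives the computed array:

import numpy as np

x = np.arange(12).reshape(3, 4)
axis, splits = 1, np.array([1, 3])

# One storage cell per output, as the framework would pre-allocate.
outputs_storage = [[None] for _ in splits]

# Same splitting logic as perform(): cut points are the running sums.
split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
for output_storage, out in zip(outputs_storage, split_outs):
    output_storage[0] = out  # write each piece into its cell

assert outputs_storage[0][0].shape == (3, 1)
assert outputs_storage[1][0].shape == (3, 3)
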
@@ -2252,10 +2252,10 @@ def infer_shape(self, fgraph, node, in_shapes):
             out_shapes.append(temp)
         return out_shapes
 
-    def grad(self, inputs, g_outputs):
+    def L_op(self, inputs, outputs, g_outputs):
         """Join the gradients along the axis that was used to split x."""
         x, axis, n = inputs
-        outputs = self(*inputs, return_list=True)
+
         # If all the output gradients are disconnected, then so are the inputs
         if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs):
             return [
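
The grad-to-L_op change works because PyTensor passes L_op(inputs, outputs, output_grads) the node's existing symbolic outputs, so the method no longer has to rebuild them with self(*inputs, return_list=True). A hedged end-to-end sketch of the behavior the docstring describes (pt.split and pytensor.grad are the public API entry points assumed here; the commented result is what joining the per-piece gradients along the split axis should yield):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# Split a length-5 vector into pieces of sizes 2 and 3.
a, b = pt.split(x, splits_size=[2, 3], n_splits=2, axis=0)
cost = a.sum() + 2 * b.sum()

# Differentiating through the split joins the per-piece gradients
# back along axis 0.
g = pytensor.grad(cost, x)
f = pytensor.function([x], g)
print(f(np.zeros(5, dtype=x.dtype)))  # expected: [1. 1. 2. 2. 2.]
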