@@ -2171,8 +2171,6 @@ class Split(COp):
21712171 array([3, 4])
21722172 >>> c
21732173 array([5])
2174-
2175- TODO: Don't make a copy in C impl
21762174 """
21772175
21782176 len_splits = None
@@ -2283,29 +2281,7 @@ def R_op(self, inputs, eval_points):
22832281 return self .make_node (eval_points [0 ], * inputs [1 :]).outputs
22842282
def c_code_cache_version(self):
    """Return the cache-version tuple for this Op's C implementation.

    The value is ``(3,)``: the generated C code no longer relies on a
    separate support-code helper, so no ``c_support_code`` accompanies it.
    """
    version = (3,)
    return version
23092285
23102286 def c_code (self , node , name , inputs , outputs , sub ):
23112287 if self .len_splits == 0 :
@@ -2316,109 +2292,96 @@ def c_code(self, node, name, inputs, outputs, sub):
23162292 outputs_pointers = "&" + (", &" .join (outputs ))
23172293 x , axis , splits = inputs
23182294 fail = sub ["fail" ]
2319- x_typenum = np .dtype (node .inputs [0 ].dtype ).num
2320- x_itemsize = np .dtype (node .inputs [0 ].dtype ).itemsize
2321- axis_dtype = node .inputs [1 ].type .dtype_specs ()[1 ]
23222295 splits_dtype = node .inputs [2 ].type .dtype_specs ()[1 ]
23232296 expected_splits_count = self .len_splits
2297+ ndim = node .inputs [0 ].type .ndim
2298+
2299+ # Most times axis is constant, inline it
2300+ # This is safe to do because the hash of the c_code includes the constant signature
2301+ if isinstance (node .inputs [1 ], Constant ):
2302+ static_axis = int (node .inputs [1 ].data )
2303+ static_axis = normalize_axis_index (static_axis , ndim )
2304+ axis_def = f"{ static_axis } ;"
2305+ axis_check = ""
2306+ else :
2307+ axis_dtype = node .inputs [1 ].type .dtype_specs ()[1 ]
2308+ axis_def = f"(({ axis_dtype } *)PyArray_DATA({ axis } ))[0];"
2309+ axis_check = f"""
2310+ if (axis < 0){{
2311+ axis = ndim + axis;
2312+ }}
2313+ if (axis >= ndim || axis < 0) {{
2314+ PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds");
2315+ { fail }
2316+ }}
2317+ """
23242318
23252319 return f"""
2326- int ndim = PyArray_NDIM( { x } ) ;
2327- int axis = (int)(*( { axis_dtype } *)PyArray_GETPTR1( { axis } , 0));
2320+ int ndim = { ndim } ;
2321+ int axis = { axis_def }
23282322 int splits_count = PyArray_DIM({ splits } , 0);
2329- npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
2330- npy_intp* split_dims = NULL;
2331- PyObject* split_view = NULL;
2332- npy_intp data_offset;
2333- int i;
2323+ npy_intp sum_of_splits = 0, current_split_start = 0;
23342324 PyArrayObject** outputs[] = {{{ outputs_pointers } }};
2325+ npy_intp split_dims[ndim];
2326+ PyObject* split_view = NULL;
23352327
23362328 /* Check inputs. */
2337-
2338- if (splits_count != { expected_splits_count } ) {{
2339- PyErr_Format(PyExc_ValueError,
2340- "Split: splits count (%d) != expected count (%d).", splits_count, { expected_splits_count } );
2329+ if (PyArray_NDIM({ x } ) != ndim) {{
2330+ PyErr_Format(PyExc_ValueError, "Input to Split does not have expected ndim");
23412331 { fail }
23422332 }}
2343-
2344- if (axis < 0) {{
2345- axis += ndim;
2346- }}
2347- if (axis < 0 || axis >= ndim) {{
2348- PyErr_Format(PyExc_IndexError, "Split: invalid axis %d for a %d-D array.", axis, ndim);
2333+ if (splits_count != { expected_splits_count } ) {{
2334+ PyErr_Format(PyExc_ValueError, "Split: splits count (%d) != expected count (%d).", splits_count, { expected_splits_count } );
23492335 { fail }
23502336 }}
2351- len_along_axis = PyArray_DIM({ x } , axis);
23522337
2353- for (i = 0; i < splits_count; ++i) {{
2354- current_split_length = (npy_intp)(*({ splits_dtype } *)PyArray_GETPTR1({ splits } , i));
2338+ { axis_check } ;
2339+
2340+ for (int i = 0; i < splits_count; ++i) {{
2341+ int current_split_length = (npy_intp)(*({ splits_dtype } *)PyArray_GETPTR1({ splits } , i));
23552342 if (current_split_length < 0) {{
23562343 PyErr_Format(PyExc_ValueError,
23572344 "Split: you try to take a negative number (%ld) of elements.", current_split_length);
23582345 { fail }
23592346 }}
23602347 sum_of_splits += current_split_length;
23612348 }}
2362- if (sum_of_splits != len_along_axis) {{
2363- PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, len_along_axis);
2364- { fail }
2365- }}
2366-
2367- /* Check outputs. */
2368-
2369- split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
2370- if (split_dims == NULL) {{
2371- PyErr_NoMemory();
2349+ if (sum_of_splits != PyArray_DIM({ x } , axis)) {{
2350+ PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, PyArray_DIM({ x } , axis));
23722351 { fail }
23732352 }}
23742353
2375- memcpy(split_dims, PyArray_DIMS({ x } ), ndim * sizeof(npy_intp));
2376-
2377- for (i = 0; i < splits_count; ++i) {{
2378- PyArrayObject** output = outputs[i];
2379- current_split_length = (npy_intp) (* ({ splits_dtype } *) PyArray_GETPTR1({ splits } , i));
2380- if (*output == NULL || !split_output_shape_is_correct(*output, { x } , axis, current_split_length)) {{
2381- Py_XDECREF(*output);
2382- split_dims[axis] = current_split_length;
2383- *output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, { x_typenum } , PyArray_IS_F_CONTIGUOUS({ x } ));
2384- if (outputs == NULL) {{
2385- PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output.");
2386- free(split_dims);
2387- { fail }
2388- }}
2389- }}
2390- }}
2391-
23922354 /* Compute split. */
2355+ memcpy(split_dims, PyArray_DIMS({ x } ), ndim * sizeof(npy_intp));
23932356
2394- for (i = 0; i < splits_count; ++i) {{
2395- current_split_length = (npy_intp) (* ({ splits_dtype } *) PyArray_GETPTR1({ splits } , i));
2396- data_offset = PyArray_STRIDE({ x } , axis) * current_split_start;
2397- split_dims[axis] = current_split_length;
2398- split_view = PyArray_New(&PyArray_Type,
2399- ndim, split_dims,
2400- { x_typenum } ,
2401- PyArray_STRIDES({ x } ),
2402- PyArray_BYTES({ x } ) + data_offset,
2403- { x_itemsize } ,
2404- PyArray_FLAGS({ x } ),
2405- NULL);
2406- if (split_view == NULL) {{
2357+ for (int i = 0; i < splits_count; ++i) {{
2358+ Py_XDECREF(*outputs[i]);
2359+
2360+ // Create view of input
2361+ PyArray_Descr *descr = PyArray_DESCR({ x } );
2362+ Py_INCREF(descr);
2363+ npy_intp data_offset = PyArray_STRIDE({ x } , axis) * current_split_start;
2364+ *outputs[i] = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type,
2365+ descr, // PyArray_NewFromDescr steals this reference
2366+ ndim, split_dims,
2367+ PyArray_STRIDES({ x } ),
2368+ PyArray_BYTES({ x } ) + data_offset,
2369+ PyArray_FLAGS({ x } ) & ~NPY_ARRAY_OWNDATA,
2370+ NULL);
2371+
2372+ if (*outputs[i] == NULL) {{
24072373 PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split.");
24082374 free(split_dims);
24092375 { fail }
24102376 }}
2411- if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {{
2412- PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
2413- Py_XDECREF(split_view);
2414- free(split_dims);
2415- { fail }
2416- }}
2417- Py_XDECREF(split_view);
2418- current_split_start += current_split_length;
2419- }}
24202377
2421- free(split_dims);
2378+ // Set as a view of input
2379+ Py_INCREF((PyObject*){ x } );
2380+ PyArray_SetBaseObject(*outputs[i], (PyObject*){ x } );
2381+
2382+ // Update split slice pointer
2383+ current_split_start += (npy_intp) (* ({ splits_dtype } *) PyArray_GETPTR1({ splits } , i));
2384+ }}
24222385 """
24232386
24242387
0 commit comments