@@ -2520,7 +2520,7 @@ def perform(self, node, inputs, output_storage):
25202520 )
25212521
25222522 def c_code_cache_version (self ):
2523- return (6 ,)
2523+ return (7 ,)
25242524
25252525 def c_code (self , node , name , inputs , outputs , sub ):
25262526 axis , * arrays = inputs
@@ -2559,16 +2559,86 @@ def c_code(self, node, name, inputs, outputs, sub):
25592559 code = f"""
25602560 int axis = { axis_def }
25612561 PyArrayObject* arrays[{ n } ] = {{{ ',' .join (arrays )} }};
2562- PyObject* arrays_tuple = PyTuple_New( { n } ) ;
2562+ int out_is_valid = { out } != NULL ;
25632563
25642564 { axis_check }
25652565
2566- Py_XDECREF({ out } );
2567- { copy_arrays_to_tuple }
2568- { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2569- Py_DECREF(arrays_tuple);
2570- if(!{ out } ){{
2571- { fail }
2566+ if (out_is_valid) {{
2567+ // Check if we can reuse output
2568+ npy_intp join_size = 0;
2569+ npy_intp out_shape[{ ndim } ];
2570+ npy_intp *shape = PyArray_SHAPE(arrays[0]);
2571+
2572+ for (int i = 0; i < { n } ; i++) {{
2573+ if (PyArray_NDIM(arrays[i]) != { ndim } ) {{
2574+ PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2575+ { fail }
2576+ }}
2577+
2578+ join_size += PyArray_SHAPE(arrays[i])[axis];
2579+
2580+ if (i > 0){{
2581+ for (int j = 0; j < { ndim } ; j++) {{
2582+ if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2583+ PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2584+ { fail }
2585+ }}
2586+ }}
2587+ }}
2588+ }}
2589+
2590+ memcpy(out_shape, shape, { ndim } * sizeof(npy_intp));
2591+ out_shape[axis] = join_size;
2592+
2593+ for (int i = 0; i < { ndim } ; i++) {{
2594+ out_is_valid &= (PyArray_SHAPE({ out } )[i] == out_shape[i]);
2595+ }}
2596+ }}
2597+
2598+ if (!out_is_valid) {{
2599+ // Use PyArray_Concatenate
2600+ Py_XDECREF({ out } );
2601+ PyObject* arrays_tuple = PyTuple_New({ n } );
2602+ { copy_arrays_to_tuple }
2603+ { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2604+ Py_DECREF(arrays_tuple);
2605+ if(!{ out } ){{
2606+ { fail }
2607+ }}
2608+ }}
2609+ else {{
2610+ // Copy the data to the pre-allocated output buffer
2611+
2612+ // Create view into output buffer
2613+ PyArrayObject_fields *view;
2614+
2615+ // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2616+ Py_INCREF(PyArray_DESCR({ out } ));
2617+ view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2618+ PyArray_DESCR({ out } ),
2619+ { ndim } ,
2620+ PyArray_SHAPE(arrays[0]),
2621+ PyArray_STRIDES({ out } ),
2622+ PyArray_DATA({ out } ),
2623+ NPY_ARRAY_WRITEABLE,
2624+ NULL);
2625+ if (view == NULL) {{
2626+ { fail }
2627+ }}
2628+
2629+ // Copy data into output buffer
2630+ for (int i = 0; i < { n } ; i++) {{
2631+ view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2632+
2633+ if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2634+ Py_DECREF(view);
2635+ { fail }
2636+ }}
2637+
2638+ view->data += (view->dimensions[axis] * view->strides[axis]);
2639+ }}
2640+
2641+ Py_DECREF(view);
25722642 }}
25732643 """
25742644 return code
0 commit comments