@@ -2518,7 +2518,7 @@ def perform(self, node, inputs, output_storage):
25182518 )
25192519
def c_code_cache_version(self):
    """Return the version tuple of the C implementation emitted by ``c_code``.

    The value is ``(7,)``. It was raised from ``(6,)`` in the same change
    that rewrote the C template in ``c_code`` to copy into a pre-allocated
    output buffer when its shape already matches.

    NOTE(review): presumably the compilation cache is keyed on this tuple,
    so it must be bumped every time the generated C source changes --
    confirm against the framework's C-op documentation.
    """
    return (7,)
25222522
25232523 def c_code (self , node , name , inputs , outputs , sub ):
25242524 axis , * arrays = inputs
@@ -2557,16 +2557,86 @@ def c_code(self, node, name, inputs, outputs, sub):
25572557 code = f"""
25582558 int axis = { axis_def }
25592559 PyArrayObject* arrays[{ n } ] = {{{ ',' .join (arrays )} }};
2560- PyObject* arrays_tuple = PyTuple_New( { n } ) ;
2560+ int out_is_valid = { out } != NULL ;
25612561
25622562 { axis_check }
25632563
2564- Py_XDECREF({ out } );
2565- { copy_arrays_to_tuple }
2566- { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2567- Py_DECREF(arrays_tuple);
2568- if(!{ out } ){{
2569- { fail }
2564+ if (out_is_valid) {{
2565+ // Check if we can reuse output
2566+ npy_intp join_size = 0;
2567+ npy_intp out_shape[{ ndim } ];
2568+ npy_intp *shape = PyArray_SHAPE(arrays[0]);
2569+
2570+ for (int i = 0; i < { n } ; i++) {{
2571+ if (PyArray_NDIM(arrays[i]) != { ndim } ) {{
2572+ PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2573+ { fail }
2574+ }}
2575+
2576+ join_size += PyArray_SHAPE(arrays[i])[axis];
2577+
2578+ if (i > 0){{
2579+ for (int j = 0; j < { ndim } ; j++) {{
2580+ if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2581+ PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2582+ { fail }
2583+ }}
2584+ }}
2585+ }}
2586+ }}
2587+
2588+ memcpy(out_shape, shape, { ndim } * sizeof(npy_intp));
2589+ out_shape[axis] = join_size;
2590+
2591+ for (int i = 0; i < { ndim } ; i++) {{
2592+ out_is_valid &= (PyArray_SHAPE({ out } )[i] == out_shape[i]);
2593+ }}
2594+ }}
2595+
2596+ if (!out_is_valid) {{
2597+ // Use PyArray_Concatenate
2598+ Py_XDECREF({ out } );
2599+ PyObject* arrays_tuple = PyTuple_New({ n } );
2600+ { copy_arrays_to_tuple }
2601+ { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2602+ Py_DECREF(arrays_tuple);
2603+ if(!{ out } ){{
2604+ { fail }
2605+ }}
2606+ }}
2607+ else {{
2608+ // Copy the data to the pre-allocated output buffer
2609+
2610+ // Create view into output buffer
2611+ PyArrayObject_fields *view;
2612+
2613+ // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2614+ Py_INCREF(PyArray_DESCR({ out } ));
2615+ view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2616+ PyArray_DESCR({ out } ),
2617+ { ndim } ,
2618+ PyArray_SHAPE(arrays[0]),
2619+ PyArray_STRIDES({ out } ),
2620+ PyArray_DATA({ out } ),
2621+ NPY_ARRAY_WRITEABLE,
2622+ NULL);
2623+ if (view == NULL) {{
2624+ { fail }
2625+ }}
2626+
2627+ // Copy data into output buffer
2628+ for (int i = 0; i < { n } ; i++) {{
2629+ view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2630+
2631+ if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2632+ Py_DECREF(view);
2633+ { fail }
2634+ }}
2635+
2636+ view->data += (view->dimensions[axis] * view->strides[axis]);
2637+ }}
2638+
2639+ Py_DECREF(view);
25702640 }}
25712641 """
25722642 return code
0 commit comments