@@ -2519,7 +2519,7 @@ def perform(self, node, inputs, output_storage):
25192519 )
25202520
# Returns the version tuple associated with this op's generated C code.
# This patch changes the returned value from (6,) to (7,) in the same diff
# that rewrites the C template built in c_code.
# NOTE(review): presumably this tuple keys a compiled-code cache, so the bump
# forces previously compiled modules to be rebuilt -- confirm against the
# framework documentation.
25212521 def c_code_cache_version (self ):
2522- return (6 ,)
2522+ return (7 ,)
25232523
# Builds the C source text for this op; the assembled string is what the
# method returns.
# `inputs` is unpacked positionally: the first element is bound to `axis`,
# all remaining elements to `arrays`. The elements of `arrays` are later
# spliced into the C template as the initializer of a PyArrayObject* array.
# NOTE(review): how node, name, outputs and sub are consumed is not visible
# in this excerpt (the diff skips those lines) -- presumably they supply the
# output name, the failure snippet and the ndim/count values used by the
# template; confirm in the full file.
25242524 def c_code (self , node , name , inputs , outputs , sub ):
25252525 axis , * arrays = inputs
@@ -2558,16 +2558,86 @@ def c_code(self, node, name, inputs, outputs, sub):
# C template for the join. Doubled braces ({{ }}) are literal C braces inside
# the f-string; single-brace fields are filled in from Python values.
#
# Lines marked "-" are the old body: unconditionally release the current
# output and build a new one with PyArray_Concatenate over a tuple of inputs.
#
# Lines marked "+" are the new body:
#   1. out_is_valid starts as "an output array already exists" (non-NULL).
#   2. If it exists, every input must have the expected ndim and must agree
#      with arrays[0] on every non-join axis; otherwise a ValueError is set
#      and the fail snippet runs. The expected output shape is arrays[0]'s
#      shape with the join axis replaced by the sum of the inputs' lengths
#      along that axis. out_is_valid stays set only if the existing output
#      matches that shape on every axis.
#   3. If the output cannot be reused, take the old route: drop the current
#      output and allocate a new one via PyArray_Concatenate.
#   4. Otherwise create a temporary array object over the existing output's
#      data pointer and strides, then for each input: set the view's length
#      along the join axis to that input's length, PyArray_CopyInto the view,
#      and advance the view's data pointer by length * stride along the join
#      axis. The view is released both when a copy fails and after the loop.
#
# NOTE(review): the reuse test compares shape only; dtype and writeability of
# the existing output are not checked in the visible code -- confirm they are
# guaranteed by whatever populates the output storage.
# NOTE(review): the explicit ndim/shape ValueErrors are raised only when an
# output already exists; with no prior output the inputs go straight to
# PyArray_Concatenate without these checks.
25582558 code = f"""
25592559 int axis = { axis_def }
25602560 PyArrayObject* arrays[{ n } ] = {{{ ',' .join (arrays )} }};
2561- PyObject* arrays_tuple = PyTuple_New( { n } ) ;
2561+ int out_is_valid = { out } != NULL ;
25622562
25632563 { axis_check }
25642564
2565- Py_XDECREF({ out } );
2566- { copy_arrays_to_tuple }
2567- { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2568- Py_DECREF(arrays_tuple);
2569- if(!{ out } ){{
2570- { fail }
2565+ if (out_is_valid) {{
2566+ // Check if we can reuse output
2567+ npy_intp join_size = 0;
2568+ npy_intp out_shape[{ ndim } ];
2569+ npy_intp *shape = PyArray_SHAPE(arrays[0]);
2570+
2571+ for (int i = 0; i < { n } ; i++) {{
2572+ if (PyArray_NDIM(arrays[i]) != { ndim } ) {{
2573+ PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2574+ { fail }
2575+ }}
2576+
2577+ join_size += PyArray_SHAPE(arrays[i])[axis];
2578+
2579+ if (i > 0){{
2580+ for (int j = 0; j < { ndim } ; j++) {{
2581+ if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2582+ PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2583+ { fail }
2584+ }}
2585+ }}
2586+ }}
2587+ }}
2588+
2589+ memcpy(out_shape, shape, { ndim } * sizeof(npy_intp));
2590+ out_shape[axis] = join_size;
2591+
2592+ for (int i = 0; i < { ndim } ; i++) {{
2593+ out_is_valid &= (PyArray_SHAPE({ out } )[i] == out_shape[i]);
2594+ }}
2595+ }}
2596+
2597+ if (!out_is_valid) {{
2598+ // Use PyArray_Concatenate
2599+ Py_XDECREF({ out } );
2600+ PyObject* arrays_tuple = PyTuple_New({ n } );
2601+ { copy_arrays_to_tuple }
2602+ { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2603+ Py_DECREF(arrays_tuple);
2604+ if(!{ out } ){{
2605+ { fail }
2606+ }}
2607+ }}
2608+ else {{
2609+ // Copy the data to the pre-allocated output buffer
2610+
2611+ // Create view into output buffer
2612+ PyArrayObject_fields *view;
2613+
2614+ // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2615+ Py_INCREF(PyArray_DESCR({ out } ));
2616+ view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2617+ PyArray_DESCR({ out } ),
2618+ { ndim } ,
2619+ PyArray_SHAPE(arrays[0]),
2620+ PyArray_STRIDES({ out } ),
2621+ PyArray_DATA({ out } ),
2622+ NPY_ARRAY_WRITEABLE,
2623+ NULL);
2624+ if (view == NULL) {{
2625+ { fail }
2626+ }}
2627+
2628+ // Copy data into output buffer
2629+ for (int i = 0; i < { n } ; i++) {{
2630+ view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2631+
2632+ if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2633+ Py_DECREF(view);
2634+ { fail }
2635+ }}
2636+
2637+ view->data += (view->dimensions[axis] * view->strides[axis]);
2638+ }}
2639+
2640+ Py_DECREF(view);
25712641 }}
25722642 """
25732643 return code
0 commit comments