@@ -2537,7 +2537,7 @@ def perform(self, node, inputs, output_storage):
25372537 )
25382538
def c_code_cache_version(self):
    """Return the version tuple of the C implementation emitted by ``c_code``.

    The value is ``(7,)``: it was raised from ``(6,)`` in the same change
    that rewrote the C template in ``c_code``.

    NOTE(review): presumably this tuple is part of the key used to look up
    previously compiled C modules, so it has to change whenever the
    generated C source changes -- confirm against the base class.

    Returns
    -------
    tuple of int
        The single-element version tuple ``(7,)``.
    """
    return (7,)
25412541
25422542 def c_code (self , node , name , inputs , outputs , sub ):
25432543 axis , * arrays = inputs
@@ -2576,16 +2576,86 @@ def c_code(self, node, name, inputs, outputs, sub):
25762576 code = f"""
25772577 int axis = { axis_def }
25782578 PyArrayObject* arrays[{ n } ] = {{{ ',' .join (arrays )} }};
2579- PyObject* arrays_tuple = PyTuple_New( { n } ) ;
2579+ int out_is_valid = { out } != NULL ;
25802580
25812581 { axis_check }
25822582
2583- Py_XDECREF({ out } );
2584- { copy_arrays_to_tuple }
2585- { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2586- Py_DECREF(arrays_tuple);
2587- if(!{ out } ){{
2588- { fail }
2583+ if (out_is_valid) {{
2584+ // Check if we can reuse output
2585+ npy_intp join_size = 0;
2586+ npy_intp out_shape[{ ndim } ];
2587+ npy_intp *shape = PyArray_SHAPE(arrays[0]);
2588+
2589+ for (int i = 0; i < { n } ; i++) {{
2590+ if (PyArray_NDIM(arrays[i]) != { ndim } ) {{
2591+ PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2592+ { fail }
2593+ }}
2594+
2595+ join_size += PyArray_SHAPE(arrays[i])[axis];
2596+
2597+ if (i > 0){{
2598+ for (int j = 0; j < { ndim } ; j++) {{
2599+ if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2600+ PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2601+ { fail }
2602+ }}
2603+ }}
2604+ }}
2605+ }}
2606+
2607+ memcpy(out_shape, shape, { ndim } * sizeof(npy_intp));
2608+ out_shape[axis] = join_size;
2609+
2610+ for (int i = 0; i < { ndim } ; i++) {{
2611+ out_is_valid &= (PyArray_SHAPE({ out } )[i] == out_shape[i]);
2612+ }}
2613+ }}
2614+
2615+ if (!out_is_valid) {{
2616+ // Use PyArray_Concatenate
2617+ Py_XDECREF({ out } );
2618+ PyObject* arrays_tuple = PyTuple_New({ n } );
2619+ { copy_arrays_to_tuple }
2620+ { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2621+ Py_DECREF(arrays_tuple);
2622+ if(!{ out } ){{
2623+ { fail }
2624+ }}
2625+ }}
2626+ else {{
2627+ // Copy the data to the pre-allocated output buffer
2628+
2629+ // Create view into output buffer
2630+ PyArrayObject_fields *view;
2631+
2632+ // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2633+ Py_INCREF(PyArray_DESCR({ out } ));
2634+ view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2635+ PyArray_DESCR({ out } ),
2636+ { ndim } ,
2637+ PyArray_SHAPE(arrays[0]),
2638+ PyArray_STRIDES({ out } ),
2639+ PyArray_DATA({ out } ),
2640+ NPY_ARRAY_WRITEABLE,
2641+ NULL);
2642+ if (view == NULL) {{
2643+ { fail }
2644+ }}
2645+
2646+ // Copy data into output buffer
2647+ for (int i = 0; i < { n } ; i++) {{
2648+ view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2649+
2650+ if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2651+ Py_DECREF(view);
2652+ { fail }
2653+ }}
2654+
2655+ view->data += (view->dimensions[axis] * view->strides[axis]);
2656+ }}
2657+
2658+ Py_DECREF(view);
25892659 }}
25902660 """
25912661 return code
0 commit comments