@@ -2537,58 +2537,123 @@ def perform(self, node, inputs, output_storage):
25372537 )
25382538
def c_code_cache_version(self):
    """Return the version tag that keys the compiled-C cache for this Op.

    The tag must change whenever the text produced by ``c_code`` changes,
    otherwise a previously compiled module could be reused for the new
    template.
    """
    return (7,)

def c_code(self, node, name, inputs, outputs, sub):
    """Generate the C implementation that joins the inputs along an axis.

    Parameters
    ----------
    node
        The node being compiled. ``node.inputs[0]`` is the axis variable
        and ``node.outputs[0].type`` supplies the output dtype and ndim.
    name
        Unused here; part of the ``c_code`` signature.
    inputs
        C variable names: the axis first, then one name per array to join.
    outputs
        A single C variable name for the output array.
    sub
        Substitution dictionary; only ``sub["fail"]`` (the error-exit
        snippet) is used.

    Returns
    -------
    str
        C source that validates the inputs, reuses or allocates the output
        buffer, and copies each input into its slice of the output.
    """
    axis, *arrays = inputs
    [out] = outputs

    n = len(arrays)
    array_list = ", ".join(arrays)
    out_dtype = node.outputs[0].type.dtype_specs()[2]
    ndim = node.outputs[0].type.ndim
    fail = sub["fail"]

    # Most times axis is constant, inline it.
    # This is safe to do because the hash of the c_code includes the
    # constant signature.
    axis_var = node.inputs[0]
    if isinstance(axis_var, Constant):
        static_axis = normalize_axis_index(int(axis_var.data), ndim)
        axis_def = f"{static_axis};"
        # Already normalized and bounds-checked at code-generation time.
        axis_check = ""
    else:
        axis_dtype = axis_var.type.dtype_specs()[1]
        axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
        # A runtime axis may be negative or out of range: normalize, then
        # reject anything still outside [0, ndim).
        axis_check = f"""
        if (axis < 0){{
            axis = {ndim} + axis;
        }}
        if (axis >= {ndim} || axis < 0) {{
            PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
            {fail}
        }}
        """

    code = f"""
    int axis = {axis_def}
    PyArrayObject* arrays[{n}] = {{{array_list}}};
    int out_is_valid = 0;
    npy_intp join_size = 0;
    npy_intp offset = 0;

    // Validate input shapes and compute join size
    npy_intp *shape = PyArray_SHAPE(arrays[0]);

    {axis_check}

    for (int i = 0; i < {n}; i++) {{
        if (PyArray_NDIM(arrays[i]) != {ndim}) {{
            PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
            {fail}
        }}

        for (int j = 0; j < {ndim}; j++) {{
            if (j == axis){{
                join_size += PyArray_DIM(arrays[i], j);
            }}
            else if(PyArray_DIM(arrays[i], j) != shape[j]) {{
                PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
                {fail}
            }}
        }}
    }}

    // Define dimensions of output array
    npy_intp out_dims[{ndim}];
    memcpy(out_dims, shape, {ndim} * sizeof(npy_intp));
    out_dims[axis] = join_size;

    // Reuse output or allocate new one
    if ({out} != NULL) {{
        out_is_valid = (PyArray_NDIM({out}) == {ndim});
        // Stop as soon as a mismatch is found. In particular, never read
        // PyArray_DIM(out, i) when the rank differs, since that would index
        // past the end of the existing buffer's dims array.
        for (int i = 0; out_is_valid && i < {ndim}; i++) {{
            out_is_valid = (PyArray_DIM({out}, i) == out_dims[i]);
        }}
    }}

    if (!out_is_valid) {{
        Py_XDECREF({out});
        {out} = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
                                                      PyArray_DescrFromType({out_dtype}),
                                                      {ndim},
                                                      out_dims,
                                                      NULL, /* strides */
                                                      NULL, /* data */
                                                      NPY_ARRAY_DEFAULT,
                                                      NULL);

        if ({out} == NULL) {{
            {fail}
        }}
    }}

    // Copy data into output buffer
    for (int i = 0; i < {n}; i++) {{
        PyArrayObject *arr = arrays[i];

        // Create view into output buffer
        // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
        Py_INCREF(PyArray_DESCR({out}));
        PyArrayObject *view = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
                                                                    PyArray_DESCR({out}),
                                                                    {ndim},
                                                                    PyArray_SHAPE(arr),
                                                                    PyArray_STRIDES({out}),
                                                                    PyArray_BYTES({out}) + (offset * PyArray_STRIDES({out})[axis]),
                                                                    NPY_ARRAY_WRITEABLE,
                                                                    NULL);
        if (view == NULL) {{
            {fail}
        }}

        // Write to it
        int success = PyArray_CopyInto(view, arr);
        Py_DECREF(view);
        if (success != 0) {{
            {fail}
        }}

        offset += PyArray_DIM(arr, axis);
    }}
    """
    return code
25942659
0 commit comments