@@ -2537,58 +2537,185 @@ def perform(self, node, inputs, output_storage):
25372537 )
25382538
25392539 def c_code_cache_version (self ):
2540- return (5 ,)
2540+ return None
2541+ return (6 ,)
25412542
25422543 def c_code (self , node , name , inputs , outputs , sub ):
2543- axis , tens = inputs [0 ], inputs [1 :]
2544- view = - 1
2545- non_empty_tensor = tens [view ]
2546- input_1 = tens [0 ]
2547- l = len (tens )
2548- (out ,) = outputs
2549- fail = sub ["fail" ]
2550- adtype = node .inputs [0 ].type .dtype_specs ()[1 ]
2544+ axis , * arrays = inputs
2545+ [out ] = outputs
25512546
2552- copy_to_list = (
2553- f"""Py_INCREF({ inp } ); PyList_SetItem(list, { i } , (PyObject*){ inp } );"""
2554- for i , inp in enumerate (tens )
2555- )
2547+ n = len (arrays )
2548+ out_dtype = node .outputs [0 ].type .dtype_specs ()[2 ]
2549+ out_itemsize = np .dtype (node .outputs [0 ].dtype ).itemsize
2550+ ndim = node .outputs [0 ].type .ndim
2551+ fail = sub ["fail" ]
25562552
2557- copy_inputs_to_list = "\n " .join (copy_to_list )
2558- n = len (tens )
2553+ # Most times axis is constant, inline it
2554+ # This is safe to do because the hash of the c_code includes the constant signature
2555+ if isinstance (node .inputs [0 ], Constant ):
2556+ static_axis = int (node .inputs [0 ].data )
2557+ static_axis = normalize_axis_index (static_axis , ndim )
2558+ axis_def = f"{ static_axis } ;"
2559+ axis_check = ""
2560+ else :
2561+ axis_dtype = node .inputs [0 ].type .dtype_specs ()[1 ]
2562+ axis_def = f"(({ axis_dtype } *)PyArray_DATA({ axis } ))[0];"
2563+ axis_check = f"""
2564+ if (axis < 0){{
2565+ axis = { ndim } + axis;
2566+ }}
2567+ if (axis >= { ndim } || axis < 0) {{
2568+ PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
2569+ { fail }
2570+ }}
2571+ """
25592572
25602573 code = f"""
2561- int axis = (({ adtype } *)PyArray_DATA({ axis } ))[0];
2562- PyObject* list = PyList_New({ l } );
2563- { copy_inputs_to_list }
2564- int tensors_lens_sum;
2565- if({ view } != -1) {{
2566- tensors_lens_sum = 0;
2567-
2568- for(int i=0; i < { n } ; i++){{
2569- tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis);
2574+ int axis = { axis_def }
2575+ PyArrayObject* arrays[{ n } ] = {{{ ',' .join (arrays )} }};
2576+ npy_intp out_shape[{ ndim } ];
2577+ npy_intp join_size = 0;
2578+ int out_is_valid = 0;
2579+ PyArrayObject_fields *view;
2580+
2581+ // Validate input shapes and compute join size
2582+ npy_intp *shape = PyArray_SHAPE(arrays[0]);
2583+
2584+ { axis_check }
2585+
2586+ for (int i = 0; i < { n } ; i++) {{
2587+ if (PyArray_NDIM(arrays[i]) != { ndim } ) {{
2588+ PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2589+ { fail }
2590+ }}
2591+
2592+ join_size += PyArray_SHAPE(arrays[i])[axis];
2593+
2594+ if(i > 0){{
2595+ for (int j = 0; j < { ndim } ; j++) {{
2596+ if((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2597+ PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2598+ { fail }
2599+ }}
2600+ }}
2601+ }}
25702602 }}
2571- tensors_lens_sum -= PyArray_DIM({ non_empty_tensor } , axis);
2572- }}
2573- if({ view } != -1 && tensors_lens_sum == 0) {{
2574- Py_XDECREF({ out } );
2575- Py_INCREF({ non_empty_tensor } );
2576- { out } = { non_empty_tensor } ;
2577- }}else{{
2578- //PyObject* PyArray_Concatenate(PyObject* obj, int axis)
2579- int ndim = PyArray_NDIM({ input_1 } );
2580- if( axis < -ndim ){{
2581- PyErr_Format(PyExc_IndexError,
2582- "Join axis %d out of bounds [0, %d)", axis, ndim);
2583- { fail }
2603+
2604+ // Define dimensions of output array
2605+ memcpy(out_shape, shape, { ndim } * sizeof(npy_intp));
2606+ out_shape[axis] = join_size;
2607+
2608+ // Reuse output or allocate new one
2609+ if ({ out } != NULL) {{
2610+ out_is_valid = (PyArray_NDIM({ out } ) == { ndim } );
2611+ for (int i = 0; i < { ndim } ; i++) {{
2612+ out_is_valid &= (PyArray_SHAPE({ out } )[i] == out_shape[i]);
2613+ }}
25842614 }}
2585- Py_XDECREF({ out } );
2586- { out } = (PyArrayObject *)PyArray_Concatenate(list, axis);
2587- Py_DECREF(list);
2588- if(!{ out } ){{
2615+
2616+ if (!out_is_valid) {{
2617+ Py_XDECREF({ out } );
2618+
2619+ // Find best memory layout to match the input tensors
2620+ // Adapted from numpy PyArray_CreateMultiSortedStridePerm
2621+ // https://github.com/numpy/numpy/blob/214b9f7c6d27f48b163dd7adbf9de368ad59859f/numpy/_core/src/multiarray/shape.c#L801
2622+ int strideperm[{ ndim } ] = {{{ ',' .join (map (str , range (ndim )))} }};
2623+ npy_intp strides[{ ndim } ];
2624+
2625+ // Sort strides (insertion sort)
2626+ for (int i0 = 1; i0 < { ndim } ; ++i0) {{
2627+ int ipos = i0;
2628+ int ax_j0 = strideperm[i0];
2629+
2630+ for (int i1 = i0 - 1; i1 >= 0; --i1) {{
2631+ int ambig = 1, shouldswap = 0;
2632+ int ax_j1 = strideperm[i1];
2633+
2634+ for (int iarrays = 0; iarrays < { n } ; ++iarrays) {{
2635+ if (PyArray_SHAPE(arrays[iarrays])[ax_j0] != 1 && PyArray_SHAPE(arrays[iarrays])[ax_j1] != 1) {{
2636+ npy_intp stride0 = PyArray_STRIDES(arrays[iarrays])[ax_j0];
2637+ npy_intp stride1 = PyArray_STRIDES(arrays[iarrays])[ax_j1];
2638+ if (stride0 < 0) stride0 = -stride0;
2639+ if (stride1 < 0) stride1 = -stride1;
2640+
2641+ if (stride0 <= stride1) {{
2642+ shouldswap = 0;
2643+ }}
2644+ else if (ambig) {{
2645+ shouldswap = 1;
2646+ }}
2647+ ambig = 0;
2648+ }}
2649+ }}
2650+
2651+ if (!ambig) {{
2652+ if (shouldswap) {{
2653+ ipos = i1;
2654+ }}
2655+ else {{
2656+ break;
2657+ }}
2658+ }}
2659+ }}
2660+
2661+ if (ipos != i0) {{
2662+ for (int i1 = i0; i1 > ipos; --i1) {{
2663+ strideperm[i1] = strideperm[i1-1];
2664+ }}
2665+ strideperm[ipos] = ax_j0;
2666+ }}
2667+ }}
2668+
2669+ // Calculate strides based on sorted order
2670+ npy_intp stride = { out_itemsize } ;
2671+ for (int i = { ndim } -1; i >= 0; --i) {{
2672+ int ax = strideperm[i];
2673+ strides[ax] = stride;
2674+ stride *= out_shape[ax];
2675+ }}
2676+
2677+ { out } = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
2678+ PyArray_DescrFromType({ out_dtype } ),
2679+ { ndim } ,
2680+ out_shape,
2681+ strides,
2682+ NULL, /* data */
2683+ NPY_ARRAY_DEFAULT,
2684+ NULL);
2685+
2686+ if ({ out } == NULL) {{
2687+ { fail }
2688+ }}
2689+ }}
2690+
2691+ // Create view into output buffer
2692+ // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2693+ Py_INCREF(PyArray_DESCR({ out } ));
2694+ view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2695+ PyArray_DESCR({ out } ),
2696+ { ndim } ,
2697+ PyArray_SHAPE(arrays[0]),
2698+ PyArray_STRIDES({ out } ),
2699+ PyArray_DATA({ out } ),
2700+ NPY_ARRAY_WRITEABLE,
2701+ NULL);
2702+ if (view == NULL) {{
25892703 { fail }
25902704 }}
2591- }}
2705+
2706+ // Copy data into output buffer
2707+ for (int i = 0; i < { n } ; i++) {{
2708+ view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2709+
2710+ if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2711+ Py_DECREF(view);
2712+ { fail }
2713+ }}
2714+
2715+ view->data += (view->dimensions[axis] * view->strides[axis]);
2716+ }}
2717+
2718+ Py_DECREF(view);
25922719 """
25932720 return code
25942721
0 commit comments