11#section support_code_apply
22
3- int APPLY_SPECIFIC (cpu_dimshuffle )(PyArrayObject * input , PyArrayObject * * res ,
4- PARAMS_TYPE * params ) {
5-
6- // This points to either the original input or a copy we create below.
7- // Either way, this is what we should be working on/with.
8- PyArrayObject * _input ;
9-
10- if (* res )
11- Py_XDECREF (* res );
12-
13- if (params -> inplace ) {
14- _input = input ;
15- Py_INCREF ((PyObject * )_input );
16- } else {
17- _input = (PyArrayObject * )PyArray_FromAny (
18- (PyObject * )input , NULL , 0 , 0 , NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY ,
19- NULL );
20- }
21-
22- PyArray_Dims permute ;
23-
24- if (!PyArray_IntpConverter ((PyObject * )params -> transposition , & permute )) {
25- return 1 ;
26- }
27-
28- /*
29- res = res.transpose(self.transposition)
30- */
31- PyArrayObject * transposed_input =
32- (PyArrayObject * )PyArray_Transpose (_input , & permute );
3+ int APPLY_SPECIFIC (cpu_dimshuffle )(PyArrayObject * input , PyArrayObject * * res , PARAMS_TYPE * params ) {
4+ npy_int64 * new_order ;
5+ npy_intp nd_in ;
6+ npy_intp nd_out ;
7+ npy_intp * dimensions ;
8+ npy_intp * strides ;
9+
10+ // This points to either the original input or a copy we create below.
11+ // Either way, this is what we should be working on/with.
12+ PyArrayObject * _input ;
13+
14+ if (!PyArray_IS_C_CONTIGUOUS (params -> _new_order )) {
15+ PyErr_SetString (PyExc_RuntimeError , "DimShuffle: param _new_order must be C-contiguous." );
16+ return 1 ;
17+ }
18+ new_order = (npy_int64 * ) PyArray_DATA (params -> _new_order );
19+ nd_in = (npy_intp )(params -> input_ndim );
20+ nd_out = PyArray_SIZE (params -> _new_order );
3321
34- Py_DECREF (_input );
22+ if (PyArray_NDIM (input ) != nd_in ) {
23+ PyErr_SetString (PyExc_NotImplementedError , "DimShuffle: Input has less dimensions than expected." );
24+ return 1 ;
25+ }
3526
36- PyDimMem_FREE (permute .ptr );
27+ // Compute new dimensions and strides
28+ dimensions = (npy_intp * ) malloc (nd_out * sizeof (npy_intp ));
29+ strides = (npy_intp * ) malloc (nd_out * sizeof (npy_intp ));
30+ if (dimensions == NULL || strides == NULL ) {
31+ PyErr_NoMemory ();
32+ free (dimensions );
33+ free (strides );
34+ return 1 ;
35+ };
36+
37+ npy_intp original_size = PyArray_SIZE (_input );
38+ npy_intp new_size = 1 ;
39+ for (npy_intp i = 0 ; i < nd_out ; ++ i ) {
40+ if (new_order [i ] != -1 ) {
41+ dimensions [i ] = PyArray_DIMS (_input )[new_order [i ]];
42+ strides [i ] = PyArray_DIMS (_input )[new_order [i ]] == 1 ? 0 : PyArray_STRIDES (_input )[new_order [i ]];
43+ } else {
44+ dimensions [i ] = 1 ;
45+ strides [i ] = 0 ;
46+ }
47+ new_size *= dimensions [i ];
48+ }
3749
38- npy_intp * res_shape = PyArray_DIMS (transposed_input );
39- npy_intp N_shuffle = PyArray_SIZE (params -> shuffle );
40- npy_intp N_augment = PyArray_SIZE (params -> augment );
41- npy_intp N = N_augment + N_shuffle ;
42- npy_intp * _reshape_shape = PyDimMem_NEW (N );
50+ if (original_size != new_size ) {
51+ PyErr_SetString (PyExc_ValueError , "DimShuffle: Attempting to squeeze axes with size not equal to one." );
52+ free (dimensions );
53+ free (strides );
54+ return 1 ;
55+ }
4356
44- if (_reshape_shape == NULL ) {
45- PyErr_NoMemory ();
46- return 1 ;
47- }
57+ if (* res )
58+ Py_XDECREF (* res );
4859
49- /*
50- shape = list(res.shape[: len(self.shuffle)])
51- for augm in self.augment:
52- shape.insert(augm, 1)
53- */
54- npy_intp aug_idx = 0 ;
55- int res_idx = 0 ;
56- for (npy_intp i = 0 ; i < N ; i ++ ) {
57- if (aug_idx < N_augment &&
58- i == * ((npy_intp * )PyArray_GetPtr (params -> augment , & aug_idx ))) {
59- _reshape_shape [i ] = 1 ;
60- aug_idx ++ ;
60+ if (params -> inplace ) {
61+ _input = input ;
62+ Py_INCREF ((PyObject * )_input );
6163 } else {
62- _reshape_shape [i ] = res_shape [res_idx ];
63- res_idx ++ ;
64+ _input = (PyArrayObject * )PyArray_FromAny (
65+ (PyObject * )input , NULL , 0 , 0 , NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY ,
66+ NULL );
6467 }
65- }
6668
67- PyArray_Dims reshape_shape = {.ptr = _reshape_shape , .len = (int )N };
68-
69- /* res = res.reshape(shape) */
70- * res = (PyArrayObject * )PyArray_Newshape (transposed_input , & reshape_shape ,
71- NPY_CORDER );
72-
73- Py_DECREF (transposed_input );
69+ // Create the new array.
70+ * res = (PyArrayObject * )PyArray_New (& PyArray_Type , nd_out , dimensions ,
71+ PyArray_TYPE (_input ), strides ,
72+ PyArray_DATA (_input ), PyArray_ITEMSIZE (_input ),
73+ // borrow only the writable flag from the base
74+ // the NPY_OWNDATA flag will default to 0.
75+ (NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE (_input )),
76+ NULL );
77+
78+ if (* res == NULL ) {
79+ free (dimensions );
80+ free (strides );
81+ return 1 ;
82+ }
7483
75- PyDimMem_FREE (reshape_shape .ptr );
84+ // recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
85+ PyArray_UpdateFlags (* res , NPY_ARRAY_UPDATE_ALL );
7686
77- if (!* res ) {
78- return 1 ;
79- }
87+ // we are making a view in both inplace and non-inplace cases
88+ PyArray_SetBaseObject (* res , (PyObject * )_input );
8089
81- return 0 ;
82- }
90+ free (strides );
91+ free (dimensions );
92+ return 0 ;
93+ }
0 commit comments