@@ -110,14 +110,28 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft):
110110 return dsc , out_strides
111111
112112
113- def _complex_nd_fft (a , s , norm , out , forward , in_place , c2c , axes , batch_fft ):
113+ def _complex_nd_fft (
114+ a ,
115+ s ,
116+ norm ,
117+ out ,
118+ forward ,
119+ in_place ,
120+ c2c ,
121+ axes ,
122+ batch_fft ,
123+ * ,
124+ reversed_axes = True ,
125+ ):
114126 """Computes complex-to-complex FFT of the input N-D array."""
115127
116128 len_axes = len (axes )
117129 # OneMKL supports up to 3-dimensional FFT on GPU
118130 # repeated axis in OneMKL FFT is not allowed
119131 if len_axes > 3 or len (set (axes )) < len_axes :
120- axes_chunk , shape_chunk = _extract_axes_chunk (axes , s , chunk_size = 3 )
132+ axes_chunk , shape_chunk = _extract_axes_chunk (
133+ axes , s , chunk_size = 3 , reversed_axes = reversed_axes
134+ )
121135 for i , (s_chunk , a_chunk ) in enumerate (zip (shape_chunk , axes_chunk )):
122136 a = _truncate_or_pad (a , shape = s_chunk , axes = a_chunk )
123137 # if out is used in an intermediate step, it will have memory
@@ -291,7 +305,7 @@ def _copy_array(x, complex_input):
291305 return x , copy_flag
292306
293307
294- def _extract_axes_chunk (a , s , chunk_size = 3 ):
308+ def _extract_axes_chunk (a , s , chunk_size = 3 , reversed_axes = True ):
295309 """
296310 Classify the first input into a list of lists with each list containing
297311 only unique values in reverse order and its length is at most `chunk_size`.
@@ -362,7 +376,10 @@ def _extract_axes_chunk(a, s, chunk_size=3):
362376 a_chunks .append (a_current_chunk [::- 1 ])
363377 s_chunks .append (s_current_chunk [::- 1 ])
364378
365- return a_chunks [::- 1 ], s_chunks [::- 1 ]
379+ if reversed_axes :
380+ return a_chunks [::- 1 ], s_chunks [::- 1 ]
381+
382+ return a_chunks , s_chunks
366383
367384
368385def _fft (a , norm , out , forward , in_place , c2c , axes , batch_fft = True ):
@@ -531,9 +548,12 @@ def _validate_out_keyword(a, out, s, axes, c2c, c2r, r2c):
531548 expected_shape [axes [- 1 ]] = s [- 1 ] // 2 + 1
532549 elif c2c :
533550 expected_shape [axes [- 1 ]] = s [- 1 ]
534- for s_i , axis in zip (s [- 2 ::- 1 ], axes [- 2 ::- 1 ]):
535- expected_shape [axis ] = s_i
551+ if r2c or c2c :
552+ for s_i , axis in zip (s [- 2 ::- 1 ], axes [- 2 ::- 1 ]):
553+ expected_shape [axis ] = s_i
536554 if c2r :
555+ for s_i , axis in zip (s [:- 1 ], axes [:- 1 ]):
556+ expected_shape [axis ] = s_i
537557 expected_shape [axes [- 1 ]] = s [- 1 ]
538558
539559 if out .shape != tuple (expected_shape ):
@@ -717,6 +737,7 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
717737 c2c = True ,
718738 axes = axes [:- 1 ],
719739 batch_fft = a .ndim != len_axes - 1 ,
740+ reversed_axes = False ,
720741 )
721742 a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
722743 if c2r :
0 commit comments