@@ -78,11 +78,13 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft):
7878 shape = a_shape [index :]
7979 strides = (0 ,) + a_strides [index :]
8080 if c2c : # c2c FFT
81+ assert dpnp .issubdtype (a .dtype , dpnp .complexfloating )
8182 if a .dtype == dpnp .complex64 :
8283 dsc = fi .Complex64Descriptor (shape )
8384 else :
8485 dsc = fi .Complex128Descriptor (shape )
8586 else : # r2c/c2r FFT
87+ assert dpnp .issubdtype (a .dtype , dpnp .inexact )
8688 if a .dtype in [dpnp .float32 , dpnp .complex64 ]:
8789 dsc = fi .Real32Descriptor (shape )
8890 else :
@@ -262,12 +264,14 @@ def _copy_array(x, complex_input):
262264 in-place FFT can be performed.
263265 """
264266 dtype = x .dtype
267+ copy_flag = False
265268 if numpy .min (x .strides ) < 0 :
266269 # negative stride is not allowed in OneMKL FFT
267270 # TODO: support for negative strides will be added in the future
268271 # versions of OneMKL, see discussion in MKLD-17597
269272 copy_flag = True
270- elif complex_input and not dpnp .issubdtype (dtype , dpnp .complexfloating ):
273+
274+ if complex_input and not dpnp .issubdtype (dtype , dpnp .complexfloating ):
271275 # c2c/c2r FFT, if input is not complex, convert to complex
272276 copy_flag = True
273277 if dtype in [dpnp .float16 , dpnp .float32 ]:
@@ -279,8 +283,6 @@ def _copy_array(x, complex_input):
279283 # float32 or float64 depending on device capabilities
280284 copy_flag = True
281285 dtype = map_dtype_to_device (dpnp .float64 , x .sycl_device )
282- else :
283- copy_flag = False
284286
285287 if copy_flag :
286288 x_copy = dpnp .empty_like (x , dtype = dtype , order = "C" )
0 commit comments