Skip to content

Commit d5bf3e6

Browse files
committed
get rid of direction
1 parent 8390bf7 commit d5bf3e6

File tree

1 file changed

+21
-36
lines changed

1 file changed

+21
-36
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,10 @@ cdef cnp.ndarray _process_arguments(
224224
object x,
225225
object n,
226226
object axis,
227-
object direction,
228227
long *axis_,
229228
long *n_,
230229
int *in_place,
231230
int *xnd,
232-
int *dir_,
233231
int realQ,
234232
):
235233
"""
@@ -239,11 +237,6 @@ cdef cnp.ndarray _process_arguments(
239237
cdef long n_max = 0
240238
cdef cnp.ndarray x_arr "xx_arrayObject"
241239

242-
if direction not in [-1, +1]:
243-
raise ValueError("Direction of FFT should +1 or -1")
244-
else:
245-
dir_[0] = -1 if direction is -1 else +1
246-
247240
# convert x to ndarray, ensure that strides are multiples of itemsize
248241
x_arr = PyArray_CheckFromAny(
249242
x, NULL, 0, 0,
@@ -379,18 +372,18 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
379372
"""
380373
cdef cnp.ndarray x_arr "x_arrayObject"
381374
cdef cnp.ndarray f_arr "f_arrayObject"
382-
cdef int xnd, n_max = 0, in_place, dir_
375+
cdef int xnd, n_max = 0, in_place
383376
cdef long n_, axis_
384377
cdef int x_type, f_type, status = 0
385378
cdef int ALL_HARMONICS = 1
386379
cdef char * c_error_msg = NULL
387380
cdef bytes py_error_msg
388381
cdef DftiCache *_cache
389382

390-
x_arr = _process_arguments(
391-
x, n, axis, direction, &axis_, &n_, &in_place, &xnd, &dir_, 0
392-
)
383+
if direction not in [-1, +1]:
384+
raise ValueError("Direction of FFT should +1 or -1")
393385

386+
x_arr = _process_arguments(x, n, axis, &axis_, &n_, &in_place, &xnd, 0)
394387
x_type = cnp.PyArray_TYPE(x_arr)
395388

396389
if out is not None:
@@ -424,7 +417,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
424417
_cache_capsule, capsule_name
425418
)
426419
if x_type is cnp.NPY_CDOUBLE:
427-
if dir_ < 0:
420+
if direction < 0:
428421
status = cdouble_mkl_ifft1d_in(
429422
x_arr, n_, <int> axis_, fsc, _cache
430423
)
@@ -433,7 +426,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
433426
x_arr, n_, <int> axis_, fsc, _cache
434427
)
435428
elif x_type is cnp.NPY_CFLOAT:
436-
if dir_ < 0:
429+
if direction < 0:
437430
status = cfloat_mkl_ifft1d_in(
438431
x_arr, n_, <int> axis_, fsc, _cache
439432
)
@@ -482,7 +475,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
482475
)
483476
if f_type is cnp.NPY_CDOUBLE:
484477
if x_type is cnp.NPY_DOUBLE:
485-
if dir_ < 0:
478+
if direction < 0:
486479
status = double_cdouble_mkl_ifft1d_out(
487480
x_arr,
488481
n_,
@@ -503,7 +496,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
503496
_cache,
504497
)
505498
elif x_type is cnp.NPY_CDOUBLE:
506-
if dir_ < 0:
499+
if direction < 0:
507500
status = cdouble_cdouble_mkl_ifft1d_out(
508501
x_arr, n_, <int> axis_, f_arr, fsc, _cache
509502
)
@@ -513,7 +506,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
513506
)
514507
else:
515508
if x_type is cnp.NPY_FLOAT:
516-
if dir_ < 0:
509+
if direction < 0:
517510
status = float_cfloat_mkl_ifft1d_out(
518511
x_arr,
519512
n_,
@@ -534,7 +527,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
534527
_cache,
535528
)
536529
elif x_type is cnp.NPY_CFLOAT:
537-
if dir_ < 0:
530+
if direction < 0:
538531
status = cfloat_cfloat_mkl_ifft1d_out(
539532
x_arr, n_, <int> axis_, f_arr, fsc, _cache
540533
)
@@ -566,18 +559,15 @@ def _r2c_fft1d_impl(
566559
"""
567560
cdef cnp.ndarray x_arr "x_arrayObject"
568561
cdef cnp.ndarray f_arr "f_arrayObject"
569-
cdef int xnd, in_place, dir_
562+
cdef int xnd, in_place
570563
cdef long n_, axis_
571564
cdef int x_type, f_type, status, requirement
572565
cdef int HALF_HARMONICS = 0 # give only positive index harmonics
573-
cdef int direction = 1 # dummy, only used for the sake of arg-processing
574566
cdef char * c_error_msg = NULL
575567
cdef bytes py_error_msg
576568
cdef DftiCache *_cache
577569

578-
x_arr = _process_arguments(
579-
x, n, axis, direction, &axis_, &n_, &in_place, &xnd, &dir_, 1
580-
)
570+
x_arr = _process_arguments(x, n, axis, &axis_, &n_, &in_place, &xnd, 1)
581571

582572
x_type = cnp.PyArray_TYPE(x_arr)
583573

@@ -668,20 +658,17 @@ def _c2r_fft1d_impl(
668658
"""
669659
cdef cnp.ndarray x_arr "x_arrayObject"
670660
cdef cnp.ndarray f_arr "f_arrayObject"
671-
cdef int xnd, in_place, dir_, int_n
661+
cdef int xnd, in_place, int_n
672662
cdef long n_, axis_
673663
cdef int x_type, f_type, status
674-
cdef int direction = 1 # dummy, only used for the sake of arg-processing
675664
cdef char * c_error_msg = NULL
676665
cdef bytes py_error_msg
677666
cdef DftiCache *_cache
678667

679668
int_n = _is_integral(n)
680669
# nn gives the number elements along axis of the input that we use
681670
nn = (n // 2 + 1) if int_n and n > 0 else n
682-
x_arr = _process_arguments(
683-
x, nn, axis, direction, &axis_, &n_, &in_place, &xnd, &dir_, 0
684-
)
671+
x_arr = _process_arguments(x, nn, axis, &axis_, &n_, &in_place, &xnd, 0)
685672
n_ = 2*(n_ - 1)
686673
if int_n and (n % 2 == 1):
687674
n_ += 1
@@ -770,12 +757,10 @@ def _direct_fftnd(
770757
cdef int err
771758
cdef cnp.ndarray x_arr "xxnd_arrayObject"
772759
cdef cnp.ndarray f_arr "ffnd_arrayObject"
773-
cdef int dir_, in_place, x_type, f_type
760+
cdef int in_place, x_type, f_type
774761

775762
if direction not in [-1, +1]:
776763
raise ValueError("Direction of FFT should +1 or -1")
777-
else:
778-
dir_ = -1 if direction is -1 else +1
779764

780765
# convert x to ndarray, ensure that strides are multiples of itemsize
781766
x_arr = PyArray_CheckFromAny(
@@ -815,12 +800,12 @@ def _direct_fftnd(
815800

816801
if in_place:
817802
if x_type == cnp.NPY_CDOUBLE:
818-
if dir_ == 1:
803+
if direction == 1:
819804
err = cdouble_cdouble_mkl_fftnd_in(x_arr, fsc)
820805
else:
821806
err = cdouble_cdouble_mkl_ifftnd_in(x_arr, fsc)
822807
elif x_type == cnp.NPY_CFLOAT:
823-
if dir_ == 1:
808+
if direction == 1:
824809
err = cfloat_cfloat_mkl_fftnd_in(x_arr, fsc)
825810
else:
826811
err = cfloat_cfloat_mkl_ifftnd_in(x_arr, fsc)
@@ -847,22 +832,22 @@ def _direct_fftnd(
847832
f_arr = _allocate_result(x_arr, -1, 0, f_type)
848833

849834
if x_type == cnp.NPY_CDOUBLE:
850-
if dir_ == 1:
835+
if direction == 1:
851836
err = cdouble_cdouble_mkl_fftnd_out(x_arr, f_arr, fsc)
852837
else:
853838
err = cdouble_cdouble_mkl_ifftnd_out(x_arr, f_arr, fsc)
854839
elif x_type == cnp.NPY_CFLOAT:
855-
if dir_ == 1:
840+
if direction == 1:
856841
err = cfloat_cfloat_mkl_fftnd_out(x_arr, f_arr, fsc)
857842
else:
858843
err = cfloat_cfloat_mkl_ifftnd_out(x_arr, f_arr, fsc)
859844
elif x_type == cnp.NPY_DOUBLE:
860-
if dir_ == 1:
845+
if direction == 1:
861846
err = double_cdouble_mkl_fftnd_out(x_arr, f_arr, fsc)
862847
else:
863848
err = double_cdouble_mkl_ifftnd_out(x_arr, f_arr, fsc)
864849
elif x_type == cnp.NPY_FLOAT:
865-
if dir_ == 1:
850+
if direction == 1:
866851
err = float_cfloat_mkl_fftnd_out(x_arr, f_arr, fsc)
867852
else:
868853
err = float_cfloat_mkl_ifftnd_out(x_arr, f_arr, fsc)

0 commit comments

Comments
 (0)