@@ -411,22 +411,35 @@ def test_fft_error(self, xp):
411411 assert_raises (IndexError , xp .fft .fft2 , a )
412412
413413
414+ @pytest .mark .parametrize ("func" , ["fftfreq" , "rfftfreq" ])
414415class TestFftfreq :
415- @pytest .mark .parametrize ("func" , ["fftfreq" , "rfftfreq" ])
416416 @pytest .mark .parametrize ("n" , [10 , 20 ])
417417 @pytest .mark .parametrize ("d" , [0.5 , 2 ])
418418 def test_fftfreq (self , func , n , d ):
419- expected = getattr (dpnp .fft , func )(n , d )
420- result = getattr (numpy .fft , func )(n , d )
421- assert_dtype_allclose (expected , result )
419+ result = getattr (dpnp .fft , func )(n , d )
420+ expected = getattr (numpy .fft , func )(n , d )
421+ assert_dtype_allclose (result , expected )
422422
423- @pytest .mark .parametrize ("func" , ["fftfreq" , "rfftfreq" ])
424- def test_error (self , func ):
425- # n should be an integer
426- assert_raises (ValueError , getattr (dpnp .fft , func ), 10.0 )
423+ @pytest .mark .parametrize ("dt" , [None ] + get_float_dtypes ())
424+ def test_dtype (self , func , dt ):
425+ n = 15
426+ result = getattr (dpnp .fft , func )(n , dtype = dt )
427+ expected = getattr (numpy .fft , func )(n ).astype (dt )
428+ assert_dtype_allclose (result , expected )
427429
428- # d should be an scalar
429- assert_raises (ValueError , getattr (dpnp .fft , func ), 10 , (2 ,))
430+ def test_error (self , func ):
431+ func = getattr (dpnp .fft , func )
432+ # n must be an integer
433+ assert_raises (ValueError , func , 10.0 )
434+
435+ # d must be an scalar
436+ assert_raises (ValueError , func , 10 , (2 ,))
437+
438+ # dtype must be None or a real-valued floating-point dtype
439+ # which is passed as a keyword argument only
440+ assert_raises (TypeError , func , 10 , 2 , None )
441+ assert_raises (ValueError , func , 10 , 2 , dtype = dpnp .intp )
442+ assert_raises (ValueError , func , 10 , 2 , dtype = dpnp .complex64 )
430443
431444
432445class TestFftn :
0 commit comments