Skip to content

Commit dd88a0d

Browse files
author
Vahid Tavanashad
committed
update fft tests
1 parent 8c50aff commit dd88a0d

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

dpnp/tests/test_fft.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -563,9 +563,10 @@ def test_basic(self, dtype, n, norm):
563563

564564
result = dpnp.fft.hfft(ia, n=n, norm=norm)
565565
expected = numpy.fft.hfft(a, n=n, norm=norm)
566-
# check_only_type_kind=True since NumPy always returns float64
567-
# but dpnp return float32 if input is float32
568-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
566+
flag = True if numpy_version() < "2.0.0" else False
567+
assert_dtype_allclose(
568+
result, expected, factor=24, check_only_type_kind=flag
569+
)
569570

570571
@pytest.mark.parametrize(
571572
"dtype", get_all_dtypes(no_none=True, no_complex=True)
@@ -579,7 +580,7 @@ def test_inverse(self, dtype, n, norm):
579580
result = dpnp.fft.ihfft(ia, n=n, norm=norm)
580581
expected = numpy.fft.ihfft(a, n=n, norm=norm)
581582
flag = True if numpy_version() < "2.0.0" else False
582-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
583+
assert_dtype_allclose(result, expected, check_only_type_kind=flag)
583584

584585
def test_error(self):
585586
a = dpnp.ones(11)
@@ -605,9 +606,8 @@ def test_basic(self, dtype, n, norm):
605606

606607
result = dpnp.fft.irfft(ia, n=n, norm=norm)
607608
expected = numpy.fft.irfft(a, n=n, norm=norm)
608-
# check_only_type_kind=True since NumPy always returns float64
609-
# but dpnp return float32 if input is float32
610-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
609+
flag = True if numpy_version() < "2.0.0" else False
610+
assert_dtype_allclose(result, expected, check_only_type_kind=flag)
611611

612612
@pytest.mark.parametrize("dtype", get_complex_dtypes())
613613
@pytest.mark.parametrize("n", [None, 5, 8])
@@ -771,8 +771,8 @@ def test_float16(self):
771771

772772
expected = numpy.fft.rfft(a)
773773
result = dpnp.fft.rfft(ia)
774-
# check_only_type_kind=True since Intel NumPy returns complex128
775-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
774+
flag = True if numpy_version() < "2.0.0" else False
775+
assert_dtype_allclose(result, expected, check_only_type_kind=flag)
776776

777777
@testing.with_requires("numpy>=2.0.0")
778778
@pytest.mark.parametrize("xp", [numpy, dpnp])
@@ -954,7 +954,7 @@ def test_1d_array(self):
954954

955955
result = dpnp.fft.irfftn(ia)
956956
expected = numpy.fft.irfftn(a)
957-
# TODO: change to the commented line when mkl_fft-gh-180 is merged
957+
# TODO: change to the commented line when mkl_fft-2.0.0 is released
958958
flag = True
959959
# flag = True if numpy_version() < "2.0.0" else False
960960
assert_dtype_allclose(result, expected, check_only_type_kind=flag)

0 commit comments

Comments
 (0)