Skip to content

Commit 90a26ab

Browse files
committed
Fix the array_api fft creation functions to use the custom CPU_DEVICE object
Original NumPy Commit: 7f354e500f85ec335dba4fdb53bd764c777965c0
1 parent 2c4b6c5 commit 90a26ab

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

array_api_strict/fft.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
float32,
1414
complex64,
1515
)
16-
from ._array_object import Array
16+
from ._array_object import Array, CPU_DEVICE
1717
from ._data_type_functions import astype
1818

1919
import numpy as np
@@ -244,7 +244,7 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar
244244
245245
See its docstring for more information.
246246
"""
247-
if device not in ["cpu", None]:
247+
if device not in [CPU_DEVICE, None]:
248248
raise ValueError(f"Unsupported device {device!r}")
249249
return Array._new(np.fft.fftfreq(n, d=d))
250250

@@ -254,7 +254,7 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A
254254
255255
See its docstring for more information.
256256
"""
257-
if device not in ["cpu", None]:
257+
if device not in [CPU_DEVICE, None]:
258258
raise ValueError(f"Unsupported device {device!r}")
259259
return Array._new(np.fft.rfftfreq(n, d=d))
260260

0 commit comments

Comments
 (0)