Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions array_api_strict/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,8 @@ def conj(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _complex_floating_dtypes:
raise TypeError("Only complex floating-point dtypes are allowed in conj")
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in conj")
return Array._new(np.conj(x._array), device=x.device)


Expand Down Expand Up @@ -568,8 +568,8 @@ def real(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _complex_floating_dtypes:
raise TypeError("Only complex floating-point dtypes are allowed in real")
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in real")
return Array._new(np.real(x._array), device=x.device)


Expand Down
4 changes: 2 additions & 2 deletions array_api_strict/tests/test_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def nargs(func):
"bitwise_xor": "integer or boolean",
"ceil": "real numeric",
"clip": "real numeric",
"conj": "complex floating-point",
"conj": "numeric",
"copysign": "real floating-point",
"cos": "floating-point",
"cosh": "floating-point",
Expand Down Expand Up @@ -88,7 +88,7 @@ def nargs(func):
"not_equal": "all",
"positive": "numeric",
"pow": "numeric",
"real": "complex floating-point",
"real": "numeric",
"reciprocal": "floating-point",
"remainder": "real numeric",
"round": "numeric",
Expand Down
Loading