Skip to content

Commit 0590a92

Browse files
npolina4antonwolfy
andauthored
Implement dpnp.isreal, dpnp.isrealobj, dpnp.iscomplex, dpnp.iscomplexobj. (#1916)
* Implement dpnp.isreal, dpnp.iscomplex * Applied review comments * Applied review comments * Simplified isreal and iscomplex tests --------- Co-authored-by: Anton <[email protected]>
1 parent fe1da4d commit 0590a92

File tree

6 files changed

+349
-20
lines changed

6 files changed

+349
-20
lines changed

dpnp/dpnp_iface.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
"synchronize_array_data",
7777
]
7878

79-
from dpnp import float64, isscalar
79+
from dpnp import float64
8080
from dpnp.dpnp_iface_arraycreation import *
8181
from dpnp.dpnp_iface_arraycreation import __all__ as __all__arraycreation
8282
from dpnp.dpnp_iface_bitwise import *
@@ -533,7 +533,7 @@ def get_dpnp_descriptor(
533533

534534
# If input object is a scalar, it means it was allocated on host memory.
535535
# We need to copy it to USM memory according to compute follows data.
536-
if isscalar(ext_obj):
536+
if dpnp.isscalar(ext_obj):
537537
ext_obj = array(
538538
ext_obj,
539539
dtype=alloc_dtype,
@@ -743,7 +743,7 @@ def get_usm_ndarray_or_scalar(a):
743743
744744
"""
745745

746-
return a if isscalar(a) else get_usm_ndarray(a)
746+
return a if dpnp.isscalar(a) else get_usm_ndarray(a)
747747

748748

749749
def is_supported_array_or_scalar(a):
@@ -765,7 +765,7 @@ def is_supported_array_or_scalar(a):
765765
766766
"""
767767

768-
return isscalar(a) or is_supported_array_type(a)
768+
return dpnp.isscalar(a) or is_supported_array_type(a)
769769

770770

771771
def is_supported_array_type(a):

dpnp/dpnp_iface_logic.py

Lines changed: 211 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,16 @@
6363
"greater",
6464
"greater_equal",
6565
"isclose",
66+
"iscomplex",
67+
"iscomplexobj",
6668
"isfinite",
6769
"isinf",
6870
"isnan",
6971
"isneginf",
7072
"isposinf",
73+
"isreal",
74+
"isrealobj",
75+
"isscalar",
7176
"less",
7277
"less_equal",
7378
"logical_and",
@@ -233,7 +238,7 @@ def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, **kwargs):
233238
234239
"""
235240

236-
if dpnp.isscalar(a) and dpnp.isscalar(b):
241+
if isscalar(a) and isscalar(b):
237242
# at least one of inputs has to be an array
238243
pass
239244
elif not (
@@ -244,18 +249,18 @@ def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, **kwargs):
244249
elif kwargs:
245250
pass
246251
else:
247-
if not dpnp.isscalar(rtol):
252+
if not isscalar(rtol):
248253
raise TypeError(
249254
f"An argument `rtol` must be a scalar, but got {rtol}"
250255
)
251-
if not dpnp.isscalar(atol):
256+
if not isscalar(atol):
252257
raise TypeError(
253258
f"An argument `atol` must be a scalar, but got {atol}"
254259
)
255260

256-
if dpnp.isscalar(a):
261+
if isscalar(a):
257262
a = dpnp.full_like(b, fill_value=a)
258-
elif dpnp.isscalar(b):
263+
elif isscalar(b):
259264
b = dpnp.full_like(a, fill_value=b)
260265
elif a.shape != b.shape:
261266
a, b = dpt.broadcast_arrays(a.get_array(), b.get_array())
@@ -610,6 +615,90 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
610615
)
611616

612617

618+
def iscomplex(x):
619+
"""
620+
Returns a bool array, where ``True`` if input element is complex.
621+
622+
What is tested is whether the input has a non-zero imaginary part, not if
623+
the input type is complex.
624+
625+
For full documentation refer to :obj:`numpy.iscomplex`.
626+
627+
Parameters
628+
----------
629+
x : {dpnp.ndarray, usm_ndarray}
630+
Input array.
631+
632+
Returns
633+
-------
634+
out : dpnp.ndarray
635+
Output array.
636+
637+
See Also
638+
--------
639+
:obj:`dpnp.isreal` : Returns a bool array, where ``True`` if input element
640+
is real.
641+
:obj:`dpnp.iscomplexobj` : Return ``True`` if `x` is a complex type or an
642+
array of complex numbers.
643+
644+
Examples
645+
--------
646+
>>> import dpnp as np
647+
>>> a = np.array([1+1j, 1+0j, 4.5, 3, 2, 2j])
648+
>>> np.iscomplex(a)
649+
array([ True, False, False, False, False, True])
650+
651+
"""
652+
dpnp.check_supported_arrays_type(x)
653+
if dpnp.issubdtype(x.dtype, dpnp.complexfloating):
654+
return x.imag != 0
655+
return dpnp.zeros_like(x, dtype=dpnp.bool)
656+
657+
658+
def iscomplexobj(x):
659+
"""
660+
Check for a complex type or an array of complex numbers.
661+
662+
The type of the input is checked, not the value. Even if the input has an
663+
imaginary part equal to zero, :obj:`dpnp.iscomplexobj` evaluates to
664+
``True``.
665+
666+
For full documentation refer to :obj:`numpy.iscomplexobj`.
667+
668+
Parameters
669+
----------
670+
x : array_like
671+
Input data, in any form that can be converted to an array. This
672+
includes scalars, lists, lists of tuples, tuples, tuples of tuples,
673+
tuples of lists, and ndarrays.
674+
675+
Returns
676+
-------
677+
out : bool
678+
The return value, ``True`` if `x` is of a complex type or has at least
679+
one complex element.
680+
681+
See Also
682+
--------
683+
:obj:`dpnp.isrealobj` : Return ``True`` if `x` is a not complex type or an
684+
array of complex numbers.
685+
:obj:`dpnp.iscomplex` : Returns a bool array, where ``True`` if input
686+
element is complex.
687+
688+
Examples
689+
--------
690+
>>> import dpnp as np
691+
>>> np.iscomplexobj(1)
692+
False
693+
>>> np.iscomplexobj(1+0j)
694+
True
695+
>>> np.iscomplexobj([3, 1+0j, True])
696+
True
697+
698+
"""
699+
return numpy.iscomplexobj(x)
700+
701+
613702
_ISFINITE_DOCSTRING = """
614703
Test if each element of input array is a finite number.
615704
@@ -923,6 +1012,123 @@ def isposinf(x, out=None):
9231012
return dpnp.logical_and(is_inf, signbit, out=out)
9241013

9251014

1015+
def isreal(x):
1016+
"""
1017+
Returns a bool array, where ``True`` if input element is real.
1018+
1019+
If element has complex type with zero imaginary part, the return value
1020+
for that element is ``True``.
1021+
1022+
For full documentation refer to :obj:`numpy.isreal`.
1023+
1024+
Parameters
1025+
----------
1026+
x : {dpnp.ndarray, usm_ndarray}
1027+
Input array.
1028+
1029+
Returns
1030+
-------
1031+
out : : dpnp.ndarray
1032+
Boolean array of same shape as `x`.
1033+
1034+
See Also
1035+
--------
1036+
:obj:`dpnp.iscomplex` : Returns a bool array, where ``True`` if input
1037+
element is complex.
1038+
:obj:`dpnp.isrealobj` : Return ``True`` if `x` is not a complex type.
1039+
1040+
Examples
1041+
--------
1042+
>>> import dpnp as np
1043+
>>> a = np.array([1+1j, 1+0j, 4.5, 3, 2, 2j])
1044+
>>> np.isreal(a)
1045+
array([False, True, True, True, True, False])
1046+
1047+
"""
1048+
dpnp.check_supported_arrays_type(x)
1049+
if dpnp.issubdtype(x.dtype, dpnp.complexfloating):
1050+
return x.imag == 0
1051+
return dpnp.ones_like(x, dtype=dpnp.bool)
1052+
1053+
1054+
def isrealobj(x):
1055+
"""
1056+
Return ``True`` if `x` is a not complex type or an array of complex numbers.
1057+
1058+
The type of the input is checked, not the value. So even if the input has
1059+
an imaginary part equal to zero, :obj:`dpnp.isrealobj` evaluates to
1060+
``False`` if the data type is complex.
1061+
1062+
For full documentation refer to :obj:`numpy.isrealobj`.
1063+
1064+
Parameters
1065+
----------
1066+
x : array_like
1067+
Input data, in any form that can be converted to an array. This
1068+
includes scalars, lists, lists of tuples, tuples, tuples of tuples,
1069+
tuples of lists, and ndarrays.
1070+
1071+
Returns
1072+
-------
1073+
out : bool
1074+
The return value, ``False`` if `x` is of a complex type.
1075+
1076+
See Also
1077+
--------
1078+
:obj:`dpnp.iscomplexobj` : Check for a complex type or an array of complex
1079+
numbers.
1080+
:obj:`dpnp.isreal` : Returns a bool array, where ``True`` if input element
1081+
is real.
1082+
1083+
Examples
1084+
--------
1085+
>>> import dpnp as np
1086+
>>> np.isrealobj(False)
1087+
True
1088+
>>> np.isrealobj(1)
1089+
True
1090+
>>> np.isrealobj(1+0j)
1091+
False
1092+
>>> np.isrealobj([3, 1+0j, True])
1093+
False
1094+
1095+
"""
1096+
return not iscomplexobj(x)
1097+
1098+
1099+
def isscalar(element):
1100+
"""
1101+
Returns ``True`` if the type of `element` is a scalar type.
1102+
1103+
For full documentation refer to :obj:`numpy.isscalar`.
1104+
1105+
Parameters
1106+
----------
1107+
element : any
1108+
Input argument, can be of any type and shape.
1109+
1110+
Returns
1111+
-------
1112+
out : bool
1113+
``True`` if `element` is a scalar type, ``False`` if it is not.
1114+
1115+
Examples
1116+
--------
1117+
>>> import dpnp as np
1118+
>>> np.isscalar(3.1)
1119+
True
1120+
>>> np.isscalar(np.array(3.1))
1121+
False
1122+
>>> np.isscalar([3.1])
1123+
False
1124+
>>> np.isscalar(False)
1125+
True
1126+
>>> np.isscalar("dpnp")
1127+
True
1128+
"""
1129+
return numpy.isscalar(element)
1130+
1131+
9261132
_LESS_DOCSTRING = """
9271133
Computes the less-than test results for each element `x1_i` of
9281134
the input array `x1` with the respective element `x2_i` of the input array `x2`.

dpnp/dpnp_iface_types.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
"integer",
7272
"intc",
7373
"intp",
74-
"isscalar",
7574
"issubdtype",
7675
"issubsctype",
7776
"is_type_supported",
@@ -229,16 +228,6 @@ def iinfo(dtype):
229228
return dpt.iinfo(dtype)
230229

231230

232-
def isscalar(obj):
233-
"""
234-
Returns ``True`` if the type of `obj` is a scalar type.
235-
236-
For full documentation refer to :obj:`numpy.isscalar`.
237-
238-
"""
239-
return numpy.isscalar(obj)
240-
241-
242231
def issubdtype(arg1, arg2):
243232
"""
244233
Returns ``True`` if the first argument is a type code lower/equal

tests/test_sycl_queue.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,8 @@ def test_meshgrid(device):
427427
pytest.param(
428428
"imag", [complex(1.0, 2.0), complex(3.0, 4.0), complex(5.0, 6.0)]
429429
),
430+
pytest.param("iscomplex", [1 + 1j, 1 + 0j, 4.5, 3, 2, 2j]),
431+
pytest.param("isreal", [1 + 1j, 1 + 0j, 4.5, 3, 2, 2j]),
430432
pytest.param("log", [1.0, 2.0, 4.0, 7.0]),
431433
pytest.param("log10", [1.0, 2.0, 4.0, 7.0]),
432434
pytest.param("log1p", [1.0e-10, 1.0, 2.0, 4.0, 7.0]),

tests/test_usm_type.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,8 @@ def test_norm(usm_type, ord, axis):
559559
pytest.param(
560560
"imag", [complex(1.0, 2.0), complex(3.0, 4.0), complex(5.0, 6.0)]
561561
),
562+
pytest.param("iscomplex", [1 + 1j, 1 + 0j, 4.5, 3, 2, 2j]),
563+
pytest.param("isreal", [1 + 1j, 1 + 0j, 4.5, 3, 2, 2j]),
562564
pytest.param("log", [1.0, 2.0, 4.0, 7.0]),
563565
pytest.param("log10", [1.0, 2.0, 4.0, 7.0]),
564566
pytest.param("log1p", [1.0e-10, 1.0, 2.0, 4.0, 7.0]),

0 commit comments

Comments
 (0)