Skip to content

Commit 0db29d3

Browse files
bool dtype support, all and any functions (#77)
* bool dtype support * all, any funcs
1 parent 0548a72 commit 0db29d3

File tree

7 files changed

+307
-6
lines changed

7 files changed

+307
-6
lines changed

dpnp/backend.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# *****************************************************************************
2727

2828
from libcpp.vector cimport vector
29-
from libcpp cimport bool
29+
from libcpp cimport bool as cpp_bool
3030
from dpnp.dparray cimport dparray, dparray_shape_type
3131

3232
cdef extern from "backend/backend_iface_fptr.hpp" namespace "DPNPFuncName": # need this namespace for Enum import
@@ -156,7 +156,7 @@ Logic functions
156156
cpdef dparray dpnp_equal(dparray array1, input2)
157157
cpdef dparray dpnp_greater(dparray input1, dparray input2)
158158
cpdef dparray dpnp_greater_equal(dparray input1, dparray input2)
159-
cpdef dparray dpnp_isclose(dparray input1, input2, double rtol=*, double atol=*, bool equal_nan=*)
159+
cpdef dparray dpnp_isclose(dparray input1, input2, double rtol=*, double atol=*, cpp_bool equal_nan=*)
160160
cpdef dparray dpnp_less(dparray input1, dparray input2)
161161
cpdef dparray dpnp_less_equal(dparray input1, dparray input2)
162162
cpdef dparray dpnp_logical_and(dparray input1, dparray input2)

dpnp/backend_logic.pyx

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ from dpnp.dpnp_utils cimport *
3737

3838

3939
__all__ += [
40+
"dpnp_all",
41+
"dpnp_any",
4042
"dpnp_equal",
4143
"dpnp_greater",
4244
"dpnp_greater_equal",
@@ -54,6 +56,34 @@ __all__ += [
5456
]
5557

5658

59+
cpdef dparray dpnp_all(dparray array1):
60+
cdef dparray result = dparray((1,), dtype=numpy.bool)
61+
62+
res = True
63+
for i in range(array1.size):
64+
if not numpy.bool(array1[i]):
65+
res = False
66+
break
67+
68+
result[0] = res
69+
70+
return result
71+
72+
73+
cpdef dparray dpnp_any(dparray array1):
74+
cdef dparray result = dparray((1,), dtype=numpy.bool)
75+
76+
res = False
77+
for i in range(array1.size):
78+
if numpy.bool(array1[i]):
79+
res = True
80+
break
81+
82+
result[0] = res
83+
84+
return result
85+
86+
5787
cpdef dparray dpnp_equal(dparray array1, input2):
5888
cdef dparray result = dparray(array1.shape, dtype=numpy.bool)
5989

@@ -85,7 +115,7 @@ cpdef dparray dpnp_greater_equal(dparray input1, dparray input2):
85115
return result
86116

87117

88-
cpdef dparray dpnp_isclose(dparray input1, input2, double rtol=1e-05, double atol=1e-08, bool equal_nan=False):
118+
cpdef dparray dpnp_isclose(dparray input1, input2, double rtol=1e-05, double atol=1e-08, cpp_bool equal_nan=False):
89119
cdef dparray result = dparray(input1.shape, dtype=numpy.bool)
90120

91121
if isinstance(input2, int):

dpnp/dparray.pyx

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,13 @@ using USB interface for an Intel GPU device.
3434
"""
3535

3636

37-
from libcpp cimport bool
37+
from libcpp cimport bool as cpp_bool
3838

3939
from dpnp.dpnp_iface_types import *
4040
from dpnp.dpnp_iface import *
4141
from dpnp.backend cimport *
4242
from dpnp.dpnp_iface_statistics import min, max
43+
from dpnp.dpnp_iface_logic import all, any
4344
import numpy
4445
cimport numpy
4546

@@ -346,7 +347,7 @@ cdef class dparray:
346347
elif self.dtype == numpy.int32:
347348
return (< int * > self._dparray_data)[lin_idx]
348349
elif self.dtype == numpy.bool:
349-
return (< bool * > self._dparray_data)[lin_idx]
350+
return (< cpp_bool * > self._dparray_data)[lin_idx]
350351

351352
utils.checker_throw_type_error("__getitem__", self.dtype)
352353

@@ -371,7 +372,7 @@ cdef class dparray:
371372
elif self.dtype == numpy.int32:
372373
(< int * > self._dparray_data)[lin_idx] = <int > value
373374
elif self.dtype == numpy.bool:
374-
(< bool * > self._dparray_data)[lin_idx] = <bool > value
375+
(< cpp_bool * > self._dparray_data)[lin_idx] = < cpp_bool > value
375376
else:
376377
utils.checker_throw_type_error("__setitem__", self.dtype)
377378

@@ -1029,6 +1030,40 @@ cdef class dparray:
10291030

10301031
return argmin(self, axis, out)
10311032

1033+
"""
1034+
-------------------------------------------------------------------------
1035+
Logic
1036+
-------------------------------------------------------------------------
1037+
"""
1038+
1039+
def all(self, axis=None, out=None, keepdims=False):
1040+
"""
1041+
Returns True if all elements evaluate to True.
1042+
1043+
Refer to `numpy.all` for full documentation.
1044+
1045+
See Also
1046+
--------
1047+
numpy.all : equivalent function
1048+
1049+
"""
1050+
1051+
return all(self, axis, out, keepdims)
1052+
1053+
def any(self, axis=None, out=None, keepdims=False):
1054+
"""
1055+
Returns True if any of the elements of `a` evaluate to True.
1056+
1057+
Refer to `numpy.any` for full documentation.
1058+
1059+
See Also
1060+
--------
1061+
numpy.any : equivalent function
1062+
1063+
"""
1064+
1065+
return any(self, axis, out, keepdims)
1066+
10321067
"""
10331068
-------------------------------------------------------------------------
10341069
Other attributes

dpnp/dpnp_iface_logic.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,16 @@
4141

4242

4343
import numpy
44+
import dpnp
4445

4546
from dpnp.backend import *
4647
from dpnp.dparray import dparray
4748
from dpnp.dpnp_utils import *
4849

4950

5051
__all__ = [
52+
"all",
53+
"any",
5154
"equal",
5255
"greater",
5356
"greater_equal",
@@ -65,6 +68,160 @@
6568
]
6669

6770

71+
def all(in_array1, axis=None, out=None, keepdims=False):
72+
"""
73+
Test whether all array elements along a given axis evaluate to True.
74+
75+
Parameters
76+
----------
77+
a : array_like
78+
Input array or object that can be converted to an array.
79+
axis : None or int or tuple of ints, optional
80+
Axis or axes along which a logical AND reduction is performed.
81+
The default (``axis=None``) is to perform a logical AND over all
82+
the dimensions of the input array. `axis` may be negative, in
83+
which case it counts from the last to the first axis.
84+
85+
.. versionadded:: 1.7.0
86+
87+
If this is a tuple of ints, a reduction is performed on multiple
88+
axes, instead of a single axis or all the axes as before.
89+
out : ndarray, optional
90+
Alternate output array in which to place the result.
91+
It must have the same shape as the expected output and its
92+
type is preserved (e.g., if ``dtype(out)`` is float, the result
93+
will consist of 0.0's and 1.0's). See `ufuncs-output-type` for more
94+
details.
95+
96+
keepdims : bool, optional
97+
If this is set to True, the axes which are reduced are left
98+
in the result as dimensions with size one. With this option,
99+
the result will broadcast correctly against the input array.
100+
101+
If the default value is passed, then `keepdims` will not be
102+
passed through to the `all` method of sub-classes of
103+
`ndarray`, however any non-default value will be. If the
104+
sub-class' method does not implement `keepdims` any
105+
exceptions will be raised.
106+
107+
Returns
108+
-------
109+
all : ndarray, bool
110+
A new boolean or array is returned unless `out` is specified,
111+
in which case a reference to `out` is returned.
112+
113+
See Also
114+
--------
115+
ndarray.all : equivalent method
116+
117+
any : Test whether any element along a given axis evaluates to True.
118+
119+
Notes
120+
-----
121+
Not a Number (NaN), positive infinity and negative infinity
122+
evaluate to `True` because these are not equal to zero.
123+
124+
"""
125+
126+
is_dparray1 = isinstance(in_array1, dparray)
127+
128+
if (not use_origin_backend(in_array1) and is_dparray1):
129+
if axis is not None:
130+
checker_throw_value_error("all", "axis", type(axis), None)
131+
if out is not None:
132+
checker_throw_value_error("all", "out", type(out), None)
133+
if keepdims is not False:
134+
checker_throw_value_error("all", "keepdims", keepdims, False)
135+
136+
result = dpnp_all(in_array1)
137+
138+
# scalar returned
139+
if result.shape == (1,):
140+
return result.dtype.type(result[0])
141+
142+
return result
143+
144+
return call_origin(numpy.all, axis, out, keepdims)
145+
146+
147+
def any(in_array1, axis=None, out=None, keepdims=False):
148+
"""
149+
Test whether any array element along a given axis evaluates to True.
150+
151+
Returns single boolean unless `axis` is not ``None``
152+
153+
Parameters
154+
----------
155+
a : array_like
156+
Input array or object that can be converted to an array.
157+
axis : None or int or tuple of ints, optional
158+
Axis or axes along which a logical OR reduction is performed.
159+
The default (``axis=None``) is to perform a logical OR over all
160+
the dimensions of the input array. `axis` may be negative, in
161+
which case it counts from the last to the first axis.
162+
163+
.. versionadded:: 1.7.0
164+
165+
If this is a tuple of ints, a reduction is performed on multiple
166+
axes, instead of a single axis or all the axes as before.
167+
out : ndarray, optional
168+
Alternate output array in which to place the result. It must have
169+
the same shape as the expected output and its type is preserved
170+
(e.g., if it is of type float, then it will remain so, returning
171+
1.0 for True and 0.0 for False, regardless of the type of `a`).
172+
See `ufuncs-output-type` for more details.
173+
174+
keepdims : bool, optional
175+
If this is set to True, the axes which are reduced are left
176+
in the result as dimensions with size one. With this option,
177+
the result will broadcast correctly against the input array.
178+
179+
If the default value is passed, then `keepdims` will not be
180+
passed through to the `any` method of sub-classes of
181+
`ndarray`, however any non-default value will be. If the
182+
sub-class' method does not implement `keepdims` any
183+
exceptions will be raised.
184+
185+
Returns
186+
-------
187+
any : bool or ndarray
188+
A new boolean or `ndarray` is returned unless `out` is specified,
189+
in which case a reference to `out` is returned.
190+
191+
See Also
192+
--------
193+
ndarray.any : equivalent method
194+
195+
all : Test whether all elements along a given axis evaluate to True.
196+
197+
Notes
198+
-----
199+
Not a Number (NaN), positive infinity and negative infinity evaluate
200+
to `True` because these are not equal to zero.
201+
202+
"""
203+
204+
is_dparray1 = isinstance(in_array1, dparray)
205+
206+
if (not use_origin_backend(in_array1) and is_dparray1):
207+
if axis is not None:
208+
checker_throw_value_error("any", "axis", type(axis), None)
209+
if out is not None:
210+
checker_throw_value_error("any", "out", type(out), None)
211+
if keepdims is not False:
212+
checker_throw_value_error("any", "keepdims", keepdims, False)
213+
214+
result = dpnp_any(in_array1)
215+
216+
# scalar returned
217+
if result.shape == (1,):
218+
return result.dtype.type(result[0])
219+
220+
return result
221+
222+
return call_origin(numpy.any, axis, out, keepdims)
223+
224+
68225
def equal(x1, x2):
69226
"""
70227
Return (x1 == x2) element-wise.

dpnp/dpnp_iface_types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
import numpy
3838

3939
__all__ = [
40+
"bool",
41+
"bool_",
4042
"dtype",
4143
"float",
4244
"float32",
@@ -48,6 +50,9 @@
4850
"newaxis"
4951
]
5052

53+
bool = numpy.bool
54+
bool_ = numpy.bool_
55+
5156
dtype = numpy.dtype
5257

5358
float = numpy.float

dpnp/dpnp_utils.pyx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ This module contains differnt helpers and utilities
3131
3232
"""
3333

34+
from libcpp cimport bool as cpp_bool
3435
import dpnp
3536
import dpnp.config as config
3637
import numpy
@@ -138,6 +139,8 @@ cdef long copy_values_to_dparray(dparray dst, input_obj, size_t dst_idx=0) excep
138139
( < long * > dst.get_data())[dst_idx] = elem_value
139140
elif elem_dtype == numpy.int32:
140141
( < int * > dst.get_data())[dst_idx] = elem_value
142+
elif elem_dtype == numpy.bool_ or elem_dtype == numpy.bool:
143+
(< cpp_bool * > dst.get_data())[dst_idx] = elem_value
141144
else:
142145
checker_throw_type_error("copy_values_to_dparray", elem_dtype)
143146

0 commit comments

Comments
 (0)