Skip to content

Commit 9699969

Browse files
authored
change greater, greater_equal, less, less_equal (#484)
1 parent f579050 commit 9699969

File tree

4 files changed

+96
-36
lines changed

4 files changed

+96
-36
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,11 @@ cpdef dparray dpnp_right_shift(dparray array1, dparray array2)
198198
Logic functions
199199
"""
200200
cpdef dparray dpnp_equal(dparray array1, input2)
201-
cpdef dparray dpnp_greater(dparray input1, dparray input2)
202-
cpdef dparray dpnp_greater_equal(dparray input1, dparray input2)
201+
cpdef dparray dpnp_greater(dparray input1, input2)
202+
cpdef dparray dpnp_greater_equal(dparray input1, input2)
203203
cpdef dparray dpnp_isclose(dparray input1, input2, double rtol=*, double atol=*, cpp_bool equal_nan=*)
204-
cpdef dparray dpnp_less(dparray input1, dparray input2)
205-
cpdef dparray dpnp_less_equal(dparray input1, dparray input2)
204+
cpdef dparray dpnp_less(dparray input1, input2)
205+
cpdef dparray dpnp_less_equal(dparray input1, input2)
206206
cpdef dparray dpnp_logical_and(dparray input1, dparray input2)
207207
cpdef dparray dpnp_logical_not(dparray input1)
208208
cpdef dparray dpnp_logical_or(dparray input1, dparray input2)

dpnp/dpnp_algo/dpnp_algo_logic.pyx

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,20 +97,32 @@ cpdef dparray dpnp_equal(dparray array1, input2):
9797
return result
9898

9999

100-
cpdef dparray dpnp_greater(dparray input1, dparray input2):
100+
cpdef dparray dpnp_greater(dparray input1, input2):
101+
input2_is_scalar = dpnp.isscalar(input2)
102+
101103
cdef dparray result = dparray(input1.shape, dtype=numpy.bool)
102104

103-
for i in range(result.size):
104-
result[i] = numpy.bool(input1[i] > input2[i])
105+
if input2_is_scalar:
106+
for i in range(result.size):
107+
result[i] = dpnp.bool(input1[i] > input2)
108+
else:
109+
for i in range(result.size):
110+
result[i] = dpnp.bool(input1[i] > input2[i])
105111

106112
return result
107113

108114

109-
cpdef dparray dpnp_greater_equal(dparray input1, dparray input2):
115+
cpdef dparray dpnp_greater_equal(dparray input1, input2):
116+
input2_is_scalar = dpnp.isscalar(input2)
117+
110118
cdef dparray result = dparray(input1.shape, dtype=numpy.bool)
111119

112-
for i in range(result.size):
113-
result[i] = numpy.bool(input1[i] >= input2[i])
120+
if input2_is_scalar:
121+
for i in range(result.size):
122+
result[i] = dpnp.bool(input1[i] >= input2)
123+
else:
124+
for i in range(result.size):
125+
result[i] = dpnp.bool(input1[i] >= input2[i])
114126

115127
return result
116128

@@ -155,20 +167,32 @@ cpdef dparray dpnp_isnan(dparray input1):
155167
return result
156168

157169

158-
cpdef dparray dpnp_less(dparray input1, dparray input2):
170+
cpdef dparray dpnp_less(dparray input1, input2):
171+
input2_is_scalar = dpnp.isscalar(input2)
172+
159173
cdef dparray result = dparray(input1.shape, dtype=numpy.bool)
160174

161-
for i in range(result.size):
162-
result[i] = numpy.bool(input1[i] < input2[i])
175+
if input2_is_scalar:
176+
for i in range(result.size):
177+
result[i] = dpnp.bool(input1[i] < input2)
178+
else:
179+
for i in range(result.size):
180+
result[i] = dpnp.bool(input1[i] < input2[i])
163181

164182
return result
165183

166184

167-
cpdef dparray dpnp_less_equal(dparray input1, dparray input2):
185+
cpdef dparray dpnp_less_equal(dparray input1, input2):
186+
input2_is_scalar = dpnp.isscalar(input2)
187+
168188
cdef dparray result = dparray(input1.shape, dtype=numpy.bool)
169189

170-
for i in range(result.size):
171-
result[i] = numpy.bool(input1[i] <= input2[i])
190+
if input2_is_scalar:
191+
for i in range(result.size):
192+
result[i] = dpnp.bool(input1[i] <= input2)
193+
else:
194+
for i in range(result.size):
195+
result[i] = dpnp.bool(input1[i] <= input2[i])
172196

173197
return result
174198

dpnp/dpnp_iface_logic.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,11 @@ def greater(x1, x2):
269269
270270
"""
271271

272-
if (use_origin_backend(x1)):
273-
return numpy.greater(x1, x2)
274-
275-
if isinstance(x1, dparray) or isinstance(x2, dparray):
276-
return dpnp_greater(x1, x2)
272+
if not (use_origin_backend(x1)):
273+
if not isinstance(x1, dparray):
274+
pass
275+
else:
276+
return dpnp_greater(x1, x2)
277277

278278
return numpy.greater(x1, x2)
279279

@@ -309,11 +309,11 @@ def greater_equal(x1, x2):
309309
310310
"""
311311

312-
if (use_origin_backend(x1)):
313-
return numpy.greater_equal(x1, x2)
314-
315-
if isinstance(x1, dparray) or isinstance(x2, dparray):
316-
return dpnp_greater_equal(x1, x2)
312+
if not (use_origin_backend(x1)):
313+
if not isinstance(x1, dparray):
314+
pass
315+
else:
316+
return dpnp_greater_equal(x1, x2)
317317

318318
return numpy.greater_equal(x1, x2)
319319

@@ -545,11 +545,11 @@ def less(x1, x2):
545545
546546
"""
547547

548-
if (use_origin_backend(x1)):
549-
return numpy.less(x1, x2)
550-
551-
if isinstance(x1, dparray) or isinstance(x2, dparray):
552-
return dpnp_less(x1, x2)
548+
if not (use_origin_backend(x1)):
549+
if not isinstance(x1, dparray):
550+
pass
551+
else:
552+
return dpnp_less(x1, x2)
553553

554554
return numpy.less(x1, x2)
555555

@@ -585,11 +585,11 @@ def less_equal(x1, x2):
585585
586586
"""
587587

588-
if (use_origin_backend(x1)):
589-
return numpy.less_equal(x1, x2)
590-
591-
if isinstance(x1, dparray) or isinstance(x2, dparray):
592-
return dpnp_less_equal(x1, x2)
588+
if not (use_origin_backend(x1)):
589+
if not isinstance(x1, dparray):
590+
pass
591+
else:
592+
return dpnp_less_equal(x1, x2)
593593

594594
return numpy.less_equal(x1, x2)
595595

tests/test_logic.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,42 @@ def test_any(type, shape):
7171
numpy.testing.assert_allclose(dpnp_res, np_res)
7272

7373

74+
def test_greater():
75+
a = numpy.array([1, 2, 3, 4, 5, 6, 7, 8])
76+
ia = dpnp.array(a)
77+
for i in range(len(a) + 1):
78+
np_res = (a > i)
79+
dpnp_res = (ia > i)
80+
numpy.testing.assert_equal(dpnp_res, np_res)
81+
82+
83+
def test_greater_equal():
84+
a = numpy.array([1, 2, 3, 4, 5, 6, 7, 8])
85+
ia = dpnp.array(a)
86+
for i in range(len(a) + 1):
87+
np_res = (a >= i)
88+
dpnp_res = (ia >= i)
89+
numpy.testing.assert_equal(dpnp_res, np_res)
90+
91+
92+
def test_less():
93+
a = numpy.array([1, 2, 3, 4, 5, 6, 7, 8])
94+
ia = dpnp.array(a)
95+
for i in range(len(a) + 1):
96+
np_res = (a < i)
97+
dpnp_res = (ia < i)
98+
numpy.testing.assert_equal(dpnp_res, np_res)
99+
100+
101+
def test_less_equal():
102+
a = numpy.array([1, 2, 3, 4, 5, 6, 7, 8])
103+
ia = dpnp.array(a)
104+
for i in range(len(a) + 1):
105+
np_res = (a <= i)
106+
dpnp_res = (ia <= i)
107+
numpy.testing.assert_equal(dpnp_res, np_res)
108+
109+
74110
def test_not_equal():
75111
a = numpy.array([1, 2, 3, 4, 5, 6, 7, 8])
76112
ia = dpnp.array(a)

0 commit comments

Comments
 (0)