Skip to content

Commit 457c79e

Browse files
authored
replace conv2scalar function (#766)
1 parent de7cf34 commit 457c79e

File tree

3 files changed

+14
-18
lines changed

3 files changed

+14
-18
lines changed

dpnp/dpnp_iface.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,17 @@ def asnumpy(input, order='C'):
135135
return numpy.asarray(input, order=order)
136136

137137

138-
def convert_single_elem_array_to_scalar(obj):
138+
def convert_single_elem_array_to_scalar(obj, keepdims=False):
139139
"""
140140
Convert array with single element to scalar
141141
"""
142142

143-
if obj.shape == (1,): # TODO handle shapes like (1,1,1,1)
143+
if (obj.ndim > 0) and (obj.size == 1) and (keepdims is False):
144144
return obj.dtype.type(obj[0])
145145

146146
return obj
147147

148+
148149
def get_dpnp_descriptor(ext_obj):
149150
"""
150151
Return True:

dpnp/dpnp_iface_logic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ def not_equal(x1, x2):
761761
x2_desc = dpnp.get_dpnp_descriptor(x2)
762762
is_x1_scalar = dpnp.isscalar(x1)
763763
is_x2_scalar = dpnp.isscalar(x2)
764-
if 0 and (x1_desc and x2_desc and (x1_desc or is_x1_scalar)) and (not use_origin_backend(x2) and (x2_desc or is_x2_scalar)) and not(is_x1_scalar and is_x2_scalar):
764+
if 0 and (x1_desc and x2_desc and (x1_desc or is_x1_scalar)) and (x2_desc or is_x2_scalar) and not(is_x1_scalar and is_x2_scalar):
765765
if is_x1_scalar:
766766
result = dpnp_not_equal(x2_desc, x1_desc)
767767
else:

dpnp/dpnp_iface_mathematical.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@
4040
"""
4141

4242

43-
import numpy
44-
4543
from dpnp.dpnp_algo import *
4644
from dpnp.dparray import dparray
4745
from dpnp.dpnp_utils import *
46+
4847
import dpnp
48+
import numpy
4949

5050

5151
__all__ = [
@@ -93,15 +93,6 @@
9393
]
9494

9595

96-
def convert_result_scalar(result, keepdims):
97-
# one element array result should be converted into scalar
98-
# TODO empty shape must be converted into scalar (it is not in test system)
99-
if (len(result.shape) > 0) and (result.size == 1) and (keepdims is False):
100-
return result.dtype.type(result[0])
101-
else:
102-
return result
103-
104-
10596
def abs(*args, **kwargs):
10697
"""
10798
Calculate the absolute value element-wise.
@@ -1377,8 +1368,10 @@ def prod(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, wher
13771368
elif where is not True:
13781369
pass
13791370
else:
1380-
result = dpnp_prod(x1, axis, dtype, out, keepdims, initial, where)
1381-
return convert_result_scalar(result, keepdims)
1371+
result_obj = dpnp_prod(x1, axis, dtype, out, keepdims, initial, where)
1372+
result = dpnp.convert_single_elem_array_to_scalar(result_obj, keepdims)
1373+
1374+
return result
13821375

13831376
return call_origin(numpy.prod, x1, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
13841377

@@ -1555,8 +1548,10 @@ def sum(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where
15551548
elif where is not True:
15561549
pass
15571550
else:
1558-
result = dpnp_sum(x1, axis, dtype, out, keepdims, initial, where)
1559-
return convert_result_scalar(result, keepdims)
1551+
result_obj = dpnp_sum(x1, axis, dtype, out, keepdims, initial, where)
1552+
result = dpnp.convert_single_elem_array_to_scalar(result_obj, keepdims)
1553+
1554+
return result
15601555

15611556
return call_origin(numpy.sum, x1, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
15621557

0 commit comments

Comments
 (0)