Skip to content

Commit f549605

Browse files
authored
implement convert_single_elem_array_to_scalar() function (#765)
1 parent d7aa8b8 commit f549605

File tree

6 files changed

+48
-76
lines changed

6 files changed

+48
-76
lines changed

dpnp/dpnp_iface.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
__all__ = [
5555
"array_equal",
5656
"asnumpy",
57+
"convert_single_elem_array_to_scalar",
5758
"dpnp_queue_initialize",
5859
"dpnp_queue_is_cpu",
5960
"get_dpnp_descriptor",
@@ -134,6 +135,16 @@ def asnumpy(input, order='C'):
134135
return numpy.asarray(input, order=order)
135136

136137

138+
def convert_single_elem_array_to_scalar(obj):
139+
"""
140+
Convert array with single element to scalar
141+
"""
142+
143+
if obj.shape == (1,): # TODO handle shapes like (1,1,1,1)
144+
return obj.dtype.type(obj[0])
145+
146+
return obj
147+
137148
def get_dpnp_descriptor(ext_obj):
138149
"""
139150
Return True:

dpnp/dpnp_iface_counting.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,15 @@
4040
"""
4141

4242

43-
import numpy
44-
4543
from dpnp.dpnp_algo.dpnp_algo import * # TODO need to investigate why dpnp.dpnp_algo can not be used
4644
from dpnp.dparray import dparray
4745

4846
# full module name because dpnp_iface_counting loaded from cython too early
4947
from dpnp.dpnp_utils.dpnp_algo_utils import *
5048

49+
import dpnp
50+
import numpy
51+
5152
__all__ = [
5253
'count_nonzero'
5354
]
@@ -84,11 +85,8 @@ def count_nonzero(in_array1, axis=None, *, keepdims=False):
8485
if keepdims is not False:
8586
checker_throw_value_error("count_nonzero", "keepdims", keepdims, False)
8687

87-
result = dpnp_count_nonzero(in_array1)
88-
89-
# scalar returned
90-
if result.shape == (1,):
91-
return result.dtype.type(result[0])
88+
result_obj = dpnp_count_nonzero(in_array1)
89+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
9290

9391
return result
9492

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,8 @@ def dot(x1, x2, **kwargs):
9898
dim2 = x2_desc.ndim
9999

100100
if not (dim1 >= 2 and dim2 == 1) and not (dim1 >= 2 and dim2 >= 2) and (x1_desc.dtype == x2_desc.dtype):
101-
result = dpnp_dot(x1_desc, x2_desc)
102-
103-
# scalar returned
104-
if result.shape == (1,):
105-
return result.dtype.type(result[0])
101+
result_obj = dpnp_dot(x1_desc, x2_desc)
102+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
106103

107104
return result
108105

dpnp/dpnp_iface_searching.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@
4040
"""
4141

4242

43-
import numpy
44-
4543
from dpnp.dpnp_algo import *
4644
from dpnp.dparray import dparray
4745
from dpnp.dpnp_utils import *
4846

47+
import dpnp
48+
import numpy
49+
4950

5051
__all__ = [
5152
'argmax',
@@ -101,11 +102,8 @@ def argmax(in_array1, axis=None, out=None):
101102
if out is not None:
102103
checker_throw_value_error("argmax", "out", type(out), None)
103104

104-
result = dpnp_argmax(in_array1)
105-
106-
# scalar returned
107-
if result.shape == (1,):
108-
return result.dtype.type(result[0])
105+
result_obj = dpnp_argmax(in_array1)
106+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
109107

110108
return result
111109

@@ -160,11 +158,8 @@ def argmin(in_array1, axis=None, out=None):
160158
if out is not None:
161159
checker_throw_value_error("argmin", "out", type(out), None)
162160

163-
result = dpnp_argmin(in_array1)
164-
165-
# scalar returned
166-
if result.shape == (1,):
167-
return result.dtype.type(result[0])
161+
result_obj = dpnp_argmin(in_array1)
162+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
168163

169164
return result
170165

dpnp/dpnp_iface_statistics.py

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,10 @@ def average(a, axis=None, weights=None, returned=False):
162162
elif returned:
163163
pass
164164
else:
165-
array_avg = dpnp_average(a)
165+
result_obj = dpnp_average(a)
166+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
166167

167-
# scalar returned
168-
if array_avg.shape == (1,):
169-
return array_avg.dtype.type(array_avg[0])
170-
171-
return array_avg
168+
return result
172169

173170
return call_origin(numpy.average, a, axis, weights, returned)
174171

@@ -335,11 +332,8 @@ def max(input, axis=None, out=None, keepdims=numpy._NoValue, initial=numpy._NoVa
335332
elif where is not numpy._NoValue:
336333
pass
337334
else:
338-
result = dpnp_max(input, axis=axis)
339-
340-
# scalar returned
341-
if result.shape == (1,):
342-
return result.dtype.type(result[0])
335+
result_obj = dpnp_max(input, axis=axis)
336+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
343337

344338
return result
345339

@@ -386,11 +380,8 @@ def mean(a, axis=None, **kwargs):
386380
elif a.size == 0:
387381
pass
388382
else:
389-
result = dpnp_mean(a, axis=axis)
390-
391-
# scalar returned
392-
if result.shape == (1,):
393-
return result.dtype.type(result[0])
383+
result_obj = dpnp_mean(a, axis=axis)
384+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
394385

395386
return result
396387

@@ -439,11 +430,8 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
439430
elif keepdims:
440431
pass
441432
else:
442-
result = dpnp_median(a)
443-
444-
# scalar returned
445-
if result.shape == (1,):
446-
return result.dtype.type(result[0])
433+
result_obj = dpnp_median(a)
434+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
447435

448436
return result
449437

@@ -486,11 +474,8 @@ def min(input, axis=None, out=None, keepdims=numpy._NoValue, initial=numpy._NoVa
486474
elif where is not numpy._NoValue:
487475
pass
488476
else:
489-
result = dpnp_min(input, axis=axis)
490-
491-
# scalar returned
492-
if result.shape == (1,):
493-
return result.dtype.type(result[0])
477+
result_obj = dpnp_min(input, axis=axis)
478+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
494479

495480
return result
496481

@@ -524,11 +509,8 @@ def nanvar(arr, axis=None, dtype=None, out=None, ddof=0, keepdims=numpy._NoValue
524509
elif keepdims is not numpy._NoValue:
525510
pass
526511
else:
527-
result = dpnp_nanvar(arr, ddof)
528-
529-
# scalar returned
530-
if result.shape == (1,):
531-
return result.dtype.type(result[0])
512+
result_obj = dpnp_nanvar(arr, ddof)
513+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
532514

533515
return result
534516

@@ -586,9 +568,8 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=numpy._NoValue):
586568
elif keepdims is not numpy._NoValue:
587569
pass
588570
else:
589-
result = dpnp_std(a, ddof)
590-
if result.shape == (1,):
591-
return result.dtype.type(result[0])
571+
result_obj = dpnp_std(a, ddof)
572+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
592573

593574
return result
594575

@@ -646,9 +627,8 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=numpy._NoValue):
646627
elif keepdims is not numpy._NoValue:
647628
pass
648629
else:
649-
result = dpnp_var(a, ddof)
650-
if result.shape == (1,):
651-
return result.dtype.type(result[0])
630+
result_obj = dpnp_var(a, ddof)
631+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
652632

653633
return result
654634

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,8 @@ def det(input):
146146

147147
if not use_origin_backend(input) and is_input_dparray:
148148
if input.shape[-1] == input.shape[-2]:
149-
result = dpnp_det(input)
150-
151-
# scalar returned
152-
if result.shape == (1,):
153-
return result.dtype.type(result[0])
149+
result_obj = dpnp_det(input)
150+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
154151

155152
return result
156153

@@ -301,11 +298,8 @@ def matrix_rank(input, tol=None, hermitian=False):
301298
if hermitian is not False:
302299
checker_throw_value_error("matrix_rank", "hermitian", hermitian, False)
303300

304-
result = dpnp_matrix_rank(input)
305-
306-
# scalar returned
307-
if result.shape == (1,):
308-
return result.dtype.type(result[0])
301+
result_obj = dpnp_matrix_rank(input)
302+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
309303

310304
return result
311305

@@ -392,11 +386,8 @@ def norm(input, ord=None, axis=None, keepdims=False):
392386
elif ord not in [None, 0, 3, 'fro', 'f']:
393387
pass
394388
else:
395-
result = dpnp_norm(input, ord=ord, axis=axis)
396-
397-
# scalar returned
398-
if result.shape == (1,) and axis is None:
399-
return result.dtype.type(result[0])
389+
result_obj = dpnp_norm(input, ord=ord, axis=axis)
390+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
400391

401392
return result
402393

0 commit comments

Comments
 (0)