Skip to content

Commit 85aad6b

Browse files
authored
add parameters to dpnp.prod (#636)
* add parameters to dpnp.sum
1 parent 25b75e3 commit 85aad6b

File tree

6 files changed

+149
-56
lines changed

6 files changed

+149
-56
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ INP_DLLEXPORT void dpnp_cumsum_c(void* array1_in, void* result1, size_t size);
225225
*
226226
* Empty @ref input_shape means scalar.
227227
*
228-
* @param [in] input_in Input array pointer. @ref _DataType_input type is expected
229228
* @param [out] result_out Output array pointer. @ref _DataType_output type is expected
229+
* @param [in] input_in Input array pointer. @ref _DataType_input type is expected
230230
* @param [in] input_shape Shape of @ref input_in
231231
* @param [in] input_shape_ndim Number of elements in @ref input_shape
232232
* @param [in] axes Array of axes to apply to @ref input_shape
@@ -259,14 +259,31 @@ INP_DLLEXPORT void dpnp_place_c(void* arr, long* mask, void* vals, const size_t
259259

260260
/**
261261
* @ingroup BACKEND_API
262-
* @brief Product of array elements
262+
* @brief Compute Product of input array elements.
263263
*
264-
* @param [in] array Input array.
265-
* @param [in] size Number of input elements in `array`.
266-
* @param [out] result Output array contains one element.
264+
* Input array is expected as @ref _DataType_input type and assume result as @ref _DataType_output type.
265+
* The function creates no memory.
266+
*
267+
* Empty @ref input_shape means scalar.
268+
*
269+
* @param [out] result_out Output array pointer. @ref _DataType_output type is expected
270+
* @param [in] input_in Input array pointer. @ref _DataType_input type is expected
271+
* @param [in] input_shape Shape of @ref input_in
272+
* @param [in] input_shape_ndim Number of elements in @ref input_shape
273+
* @param [in] axes Array of axes to apply to @ref input_shape
274+
* @param [in] axes_ndim Number of elements in @ref axes
275+
* @param [in] initial Pointer to initial value for the algorithm. @ref _DataType_input is expected
276+
* @param [in] where mask array
267277
*/
268-
template <typename _DataType>
269-
INP_DLLEXPORT void dpnp_prod_c(void* array, void* result, size_t size);
278+
template <typename _DataType_input, typename _DataType_output>
279+
INP_DLLEXPORT void dpnp_prod_c(void* result_out,
280+
const void* input_in,
281+
const size_t* input_shape,
282+
const size_t input_shape_ndim,
283+
const long* axes,
284+
const size_t axes_ndim,
285+
const void* initial,
286+
const long* where);
270287

271288
/**
272289
* @ingroup BACKEND_API

dpnp/backend/kernels/dpnp_krnl_reduction.cpp

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ _DataType* get_array_ptr(const void* __array)
4444
}
4545

4646
template <typename _DataType>
47-
_DataType get_initial_value(const void* __initial)
47+
_DataType get_initial_value(const void* __initial, _DataType default_val)
4848
{
4949
const _DataType* initial_ptr = reinterpret_cast<const _DataType*>(__initial);
50-
const _DataType init_val = (initial_ptr == nullptr) ? _DataType{0} : *initial_ptr;
50+
const _DataType init_val = (initial_ptr == nullptr) ? default_val : *initial_ptr;
5151

5252
return init_val;
5353
}
@@ -72,7 +72,7 @@ void dpnp_sum_c(void* result_out,
7272
return;
7373
}
7474

75-
const _DataType_output init = get_initial_value<_DataType_output>(initial);
75+
const _DataType_output init = get_initial_value<_DataType_output>(initial, 0);
7676

7777
_DataType_input* input = get_array_ptr<_DataType_input>(input_in);
7878
_DataType_output* result = get_array_ptr<_DataType_output>(result_out);
@@ -114,47 +114,90 @@ void dpnp_sum_c(void* result_out,
114114
// type of "init" determine internal algorithm accumulator type
115115
_DataType_output accumulator = std::reduce(
116116
policy, input_it.begin(output_id), input_it.end(output_id), init, std::plus<_DataType_output>());
117-
policy.queue().wait();
117+
policy.queue().wait(); // TODO move out of the loop
118118

119119
result[output_id] = accumulator;
120120
}
121121

122122
return;
123123
}
124124

125-
template <typename _KernelNameSpecialization>
125+
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2>
126126
class dpnp_prod_c_kernel;
127127

128-
template <typename _DataType>
129-
void dpnp_prod_c(void* array1_in, void* result1, size_t size)
128+
template <typename _DataType_input, typename _DataType_output>
129+
void dpnp_prod_c(void* result_out,
130+
const void* input_in,
131+
const size_t* input_shape,
132+
const size_t input_shape_ndim,
133+
const long* axes,
134+
const size_t axes_ndim,
135+
const void* initial, // type must be _DataType_output
136+
const long* where)
130137
{
131-
if (!size)
138+
(void)where; // avoid warning unused variable
139+
140+
if ((input_in == nullptr) || (result_out == nullptr))
132141
{
133142
return;
134143
}
135144

136-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
137-
_DataType* result = reinterpret_cast<_DataType*>(result1);
145+
const _DataType_output init = get_initial_value<_DataType_output>(initial, 1);
146+
147+
_DataType_input* input = get_array_ptr<_DataType_input>(input_in);
148+
_DataType_output* result = get_array_ptr<_DataType_output>(result_out);
149+
150+
if (!input_shape && !input_shape_ndim)
151+
{ // it is a scalar
152+
result[0] = input[0];
153+
154+
return;
155+
}
138156

139-
auto policy = oneapi::dpl::execution::make_device_policy<dpnp_prod_c_kernel<_DataType>>(DPNP_QUEUE);
157+
DPNPC_id<_DataType_input> input_it(input, input_shape, input_shape_ndim);
158+
input_it.set_axes(axes, axes_ndim);
140159

141-
result[0] = std::reduce(policy, array_1, array_1 + size, _DataType(1), std::multiplies<_DataType>());
160+
const size_t output_size = input_it.get_output_size();
161+
auto policy =
162+
oneapi::dpl::execution::make_device_policy<dpnp_prod_c_kernel<_DataType_input, _DataType_output>>(DPNP_QUEUE);
163+
for (size_t output_id = 0; output_id < output_size; ++output_id)
164+
{
165+
// type of "init" determine internal algorithm accumulator type
166+
_DataType_output accumulator = std::reduce(
167+
policy, input_it.begin(output_id), input_it.end(output_id), init, std::multiplies<_DataType_output>());
168+
policy.queue().wait(); // TODO move out of the loop
142169

143-
policy.queue().wait();
170+
result[output_id] = accumulator;
171+
}
144172

145173
return;
146174
}
147175

148176
void func_map_init_reduction(func_map_t& fmap)
149177
{
150-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_prod_c<int>};
151-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<long>};
152-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<float>};
153-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<double>};
154-
155178
// WARNING. The meaning of the fmap is changed. Second argument represents RESULT_TYPE for this function
156179
// handle "out" and "type" parameters require user selection of return type
157180
// TODO. required refactoring of fmap to some kernelSelector
181+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_INT] = {eft_LNG, (void*)dpnp_prod_c<int, int>};
182+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<int, long>};
183+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<int, float>};
184+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<int, double>};
185+
186+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_INT] = {eft_INT, (void*)dpnp_prod_c<long, int>};
187+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<long, long>};
188+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<long, float>};
189+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<long, double>};
190+
191+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_INT] = {eft_INT, (void*)dpnp_prod_c<float, int>};
192+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<float, long>};
193+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<float, float>};
194+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<float, double>};
195+
196+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_prod_c<double, int>};
197+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<double, long>};
198+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<double, float>};
199+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<double, double>};
200+
158201
fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_INT] = {eft_LNG, (void*)dpnp_sum_c<int, int>};
159202
fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_sum_c<int, long>};
160203
fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_FLT] = {eft_FLT, (void*)dpnp_sum_c<int, float>};

dpnp/dparray.pyx

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,14 @@ from libcpp cimport bool as cpp_bool
3838

3939
from dpnp.dpnp_iface_types import *
4040
from dpnp.dpnp_iface import *
41+
42+
# to avoid interference with Python internal functions
43+
from dpnp.dpnp_iface import sum as iface_sum
44+
from dpnp.dpnp_iface import prod as iface_prod
45+
4146
from dpnp.dpnp_algo cimport *
42-
from dpnp.dpnp_iface_statistics import min, max
43-
from dpnp.dpnp_iface_logic import all, any
47+
from dpnp.dpnp_iface_statistics import min, max #TODO do the same as for iface_sum
48+
from dpnp.dpnp_iface_logic import all, any #TODO do the same as for iface_sum
4449
import numpy
4550
cimport numpy
4651

@@ -829,6 +834,18 @@ cdef class dparray:
829834
-------------------------------------------------------------------------
830835
"""
831836
837+
def prod(*args, **kwargs):
838+
"""
839+
Returns the prod along a given axis.
840+
841+
.. seealso::
842+
:obj:`dpnp.prod` for full documentation,
843+
:meth:`dpnp.dparray.sum`
844+
845+
"""
846+
847+
return iface_prod(*args, **kwargs)
848+
832849
def sum(*args, **kwargs):
833850
"""
834851
Returns the sum along a given axis.
@@ -839,9 +856,7 @@ cdef class dparray:
839856

840857
"""
841858
842-
# TODO don't know how to call `sum from python public interface. Simple call executes internal `sum` function.
843-
# numpy with dparray call public dpnp.sum via __array_interface__`
844-
return numpy.sum(*args, **kwargs)
859+
return iface_sum(*args, **kwargs)
845860
846861
def max(self, axis=None):
847862
"""

dpnp/dpnp_algo/dpnp_algo_mathematical.pyx

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ cpdef dparray dpnp_power(dparray x1, x2):
374374
return call_fptr_2in_1out(DPNP_FN_POWER, x1, x2, x1.shape)
375375

376376

377-
cpdef dpnp_prod(dparray x1):
377+
cpdef dparray dpnp_prod(dparray input, object axis=None, object dtype=None, dparray out=None, cpp_bool keepdims=False, object initial=None, object where=True):
378378
"""
379379
input:float64 : outout:float64 : name:prod
380380
input:float32 : outout:float32 : name:prod
@@ -385,12 +385,25 @@ cpdef dpnp_prod(dparray x1):
385385
input:complex128: outout:complex128: name:prod
386386
"""
387387

388-
cdef dparray result = call_fptr_1in_1out(DPNP_FN_PROD, x1, (1,))
388+
cdef dparray_shape_type input_shape = input.shape
389+
cdef DPNPFuncType input_c_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
390+
391+
cdef dparray_shape_type axis_shape = _object_to_tuple(axis)
389392

390-
""" Numpy interface inconsistency """
391-
return_type = numpy.dtype(numpy.int64) if (x1.dtype == numpy.int32) else x1.dtype
393+
cdef dparray_shape_type result_shape = get_reduction_output_shape(input_shape, axis, keepdims)
394+
cdef DPNPFuncType result_c_type = get_output_c_type(DPNP_FN_PROD, input_c_type, out, dtype)
395+
396+
""" select kernel """
397+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_PROD, input_c_type, result_c_type)
392398

393-
return return_type.type(result[0])
399+
""" Create result array """
400+
cdef dparray result = create_output_array(result_shape, result_c_type, out)
401+
cdef dpnp_reduction_c_t func = <dpnp_reduction_c_t > kernel_data.ptr
402+
403+
""" Call FPTR interface function """
404+
func(result.get_data(), input.get_data(), < size_t * >input_shape.data(), input_shape.size(), axis_shape.data(), axis_shape.size(), NULL, NULL)
405+
406+
return result
394407

395408

396409
cpdef dparray dpnp_remainder(dparray x1, dparray x2):

dpnp/dpnp_iface_mathematical.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@
9393
]
9494

9595

96+
def convert_result_scalar(result, keepdims):
97+
# one element array result should be converted into scalar
98+
# TODO empty shape must be converted into scalar (it is not in test system)
99+
if (len(result.shape) > 0) and (result.size == 1) and (keepdims is False):
100+
return result.dtype.type(result[0])
101+
else:
102+
return result
103+
104+
96105
def abs(*args, **kwargs):
97106
"""
98107
Calculate the absolute value element-wise.
@@ -1235,18 +1244,17 @@ def power(x1, x2, **kwargs):
12351244
return call_origin(numpy.power, x1, x2, **kwargs)
12361245

12371246

1238-
def prod(x1, **kwargs):
1247+
def prod(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=True):
12391248
"""
12401249
Calculate product of array elements over a given axis.
12411250
12421251
For full documentation refer to :obj:`numpy.prod`.
12431252
12441253
Limitations
12451254
-----------
1246-
Parameter ``x1`` is supported as :obj:`dpnp.ndarray`.
1247-
Keyword arguments ``kwargs`` are currently unsupported.
1248-
Otherwise the functions will be executed sequentially on CPU.
1249-
Input array data types are limited by supported DPNP :ref:`Data types`.
1255+
Parameter ``x1`` is supported as :obj:`dpnp.dparray` only.
1256+
Parameter ``where`` is unsupported.
1257+
Input array data types are limited by DPNP :ref:`Data types`.
12501258
12511259
Examples
12521260
--------
@@ -1258,12 +1266,18 @@ def prod(x1, **kwargs):
12581266
12591267
"""
12601268

1261-
is_x1_dparray = isinstance(x1, dparray)
1262-
1263-
if (not use_origin_backend(x1) and is_x1_dparray and not kwargs):
1264-
return dpnp_prod(x1)
1269+
if not use_origin_backend(x1):
1270+
if not isinstance(x1, dparray):
1271+
pass
1272+
elif out is not None and not isinstance(out, dparray):
1273+
pass
1274+
elif where is not True:
1275+
pass
1276+
else:
1277+
result = dpnp_prod(x1, axis, dtype, out, keepdims, initial, where)
1278+
return convert_result_scalar(result, keepdims)
12651279

1266-
return call_origin(numpy.prod, x1, **kwargs)
1280+
return call_origin(numpy.prod, x1, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
12671281

12681282

12691283
def remainder(x1, x2, **kwargs):
@@ -1403,7 +1417,7 @@ def sum(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where
14031417
Limitations
14041418
-----------
14051419
Parameter ``x1`` is supported as :obj:`dpnp.dparray` only.
1406-
Parameters ``initial`` and ``where`` from keyword arguments ``kwargs`` are unsupported.
1420+
Parameter `where`` is unsupported.
14071421
Input array data types are limited by DPNP :ref:`Data types`.
14081422
14091423
Examples
@@ -1425,13 +1439,7 @@ def sum(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where
14251439
pass
14261440
else:
14271441
result = dpnp_sum(x1, axis, dtype, out, keepdims, initial, where)
1428-
1429-
# one element array result should be converted into scalar
1430-
# TODO empty shape must be converted into scalar (it is not in test system)
1431-
if (len(result.shape) > 0) and (result.size == 1) and (keepdims is False):
1432-
return result.dtype.type(result[0])
1433-
1434-
return result
1442+
return convert_result_scalar(result, keepdims)
14351443

14361444
return call_origin(numpy.sum, x1, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
14371445

tests/skipped_tests.tbl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,9 +1116,6 @@ tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_1
11161116
tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_15_{axis=0, func='nanprod', keepdims=False, shape=(20, 30, 40), transpose_axes=False}::test_nansum_axis_transposed
11171117
tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_9_{axis=0, func='nanprod', keepdims=True, shape=(2, 3, 4), transpose_axes=False}::test_nansum_all
11181118
tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_9_{axis=0, func='nanprod', keepdims=True, shape=(2, 3, 4), transpose_axes=False}::test_nansum_axis_transposed
1119-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_prod_all
1120-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_prod_axis
1121-
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_prod_dtype
11221119
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all2
11231120
tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed2
11241121
tests/third_party/cupy/math_tests/test_trigonometric.py::TestUnwrap::test_unwrap_1dim

0 commit comments

Comments
 (0)