Skip to content

Commit 3ebd8cb

Browse files
authored
Swap output and input parameter types in sum() and prod() (#639)
1 parent 74c8670 commit 3ebd8cb

File tree

4 files changed

+39
-39
lines changed

4 files changed

+39
-39
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ INP_DLLEXPORT void dpnp_cumsum_c(void* array1_in, void* result1, size_t size);
256256
* @param [in] initial Pointer to initial value for the algorithm. @ref _DataType_input is expected
257257
* @param [in] where mask array
258258
*/
259-
template <typename _DataType_input, typename _DataType_output>
259+
template <typename _DataType_output, typename _DataType_input>
260260
INP_DLLEXPORT void dpnp_sum_c(void* result_out,
261261
const void* input_in,
262262
const size_t* input_shape,
@@ -297,15 +297,15 @@ INP_DLLEXPORT void dpnp_place_c(void* arr, long* mask, void* vals, const size_t
297297
* @param [in] initial Pointer to initial value for the algorithm. @ref _DataType_input is expected
298298
* @param [in] where mask array
299299
*/
300-
template <typename _DataType_input, typename _DataType_output>
300+
template <typename _DataType_output, typename _DataType_input>
301301
INP_DLLEXPORT void dpnp_prod_c(void* result_out,
302-
const void* input_in,
303-
const size_t* input_shape,
304-
const size_t input_shape_ndim,
305-
const long* axes,
306-
const size_t axes_ndim,
307-
const void* initial,
308-
const long* where);
302+
const void* input_in,
303+
const size_t* input_shape,
304+
const size_t input_shape_ndim,
305+
const long* axes,
306+
const size_t axes_ndim,
307+
const void* initial,
308+
const long* where);
309309

310310
/**
311311
* @ingroup BACKEND_API

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ void dpnp_trapz_c(
335335
}
336336
else
337337
{
338-
dpnp_sum_c<_DataType_input1, _DataType_output>(result, array1, &array1_size, 1, NULL, 0, NULL, NULL);
338+
dpnp_sum_c<_DataType_output, _DataType_input1>(result, array1, &array1_size, 1, NULL, 0, NULL, NULL);
339339

340340
result[0] -= (array1[0] + array1[array1_size - 1]) * 0.5;
341341
result[0] *= dx;

dpnp/backend/kernels/dpnp_krnl_reduction.cpp

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ _DataType get_initial_value(const void* __initial, _DataType default_val)
5555
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2>
5656
class dpnp_sum_c_kernel;
5757

58-
template <typename _DataType_input, typename _DataType_output>
58+
template <typename _DataType_output, typename _DataType_input>
5959
void dpnp_sum_c(void* result_out,
6060
const void* input_in,
6161
const size_t* input_shape,
@@ -108,7 +108,7 @@ void dpnp_sum_c(void* result_out,
108108

109109
const size_t output_size = input_it.get_output_size();
110110
auto policy =
111-
oneapi::dpl::execution::make_device_policy<dpnp_sum_c_kernel<_DataType_input, _DataType_output>>(DPNP_QUEUE);
111+
oneapi::dpl::execution::make_device_policy<dpnp_sum_c_kernel<_DataType_output, _DataType_input>>(DPNP_QUEUE);
112112
for (size_t output_id = 0; output_id < output_size; ++output_id)
113113
{
114114
// type of "init" determine internal algorithm accumulator type
@@ -125,7 +125,7 @@ void dpnp_sum_c(void* result_out,
125125
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2>
126126
class dpnp_prod_c_kernel;
127127

128-
template <typename _DataType_input, typename _DataType_output>
128+
template <typename _DataType_output, typename _DataType_input>
129129
void dpnp_prod_c(void* result_out,
130130
const void* input_in,
131131
const size_t* input_shape,
@@ -159,7 +159,7 @@ void dpnp_prod_c(void* result_out,
159159

160160
const size_t output_size = input_it.get_output_size();
161161
auto policy =
162-
oneapi::dpl::execution::make_device_policy<dpnp_prod_c_kernel<_DataType_input, _DataType_output>>(DPNP_QUEUE);
162+
oneapi::dpl::execution::make_device_policy<dpnp_prod_c_kernel<_DataType_output, _DataType_input>>(DPNP_QUEUE);
163163
for (size_t output_id = 0; output_id < output_size; ++output_id)
164164
{
165165
// type of "init" determine internal algorithm accumulator type
@@ -179,43 +179,43 @@ void func_map_init_reduction(func_map_t& fmap)
179179
// handle "out" and "type" parameters require user selection of return type
180180
// TODO. required refactoring of fmap to some kernelSelector
181181
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_INT] = {eft_LNG, (void*)dpnp_prod_c<int, int>};
182-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<int, long>};
183-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<int, float>};
184-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<int, double>};
182+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<long, int>};
183+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<float, int>};
184+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<double, int>};
185185

186-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_INT] = {eft_INT, (void*)dpnp_prod_c<long, int>};
186+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_INT] = {eft_INT, (void*)dpnp_prod_c<int, long>};
187187
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<long, long>};
188-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<long, float>};
189-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<long, double>};
188+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<float, long>};
189+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<double, long>};
190190

191-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_INT] = {eft_INT, (void*)dpnp_prod_c<float, int>};
192-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<float, long>};
191+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_INT] = {eft_INT, (void*)dpnp_prod_c<int, float>};
192+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<long, float>};
193193
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<float, float>};
194-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<float, double>};
194+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<double, float>};
195195

196-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_prod_c<double, int>};
197-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<double, long>};
198-
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<double, float>};
196+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_prod_c<int, double>};
197+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_prod_c<long, double>};
198+
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_FLT] = {eft_FLT, (void*)dpnp_prod_c<float, double>};
199199
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_prod_c<double, double>};
200200

201201
fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_INT] = {eft_LNG, (void*)dpnp_sum_c<int, int>};
202-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_sum_c<int, long>};
203-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_FLT] = {eft_FLT, (void*)dpnp_sum_c<int, float>};
204-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_sum_c<int, double>};
202+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_sum_c<long, int>};
203+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_FLT] = {eft_FLT, (void*)dpnp_sum_c<float, int>};
204+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_sum_c<double, int>};
205205

206-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_LNG][eft_INT] = {eft_INT, (void*)dpnp_sum_c<long, int>};
206+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_LNG][eft_INT] = {eft_INT, (void*)dpnp_sum_c<int, long>};
207207
fmap[DPNPFuncName::DPNP_FN_SUM][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_sum_c<long, long>};
208-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_LNG][eft_FLT] = {eft_FLT, (void*)dpnp_sum_c<long, float>};
209-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_sum_c<long, double>};
208+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_LNG][eft_FLT] = {eft_FLT, (void*)dpnp_sum_c<float, long>};
209+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_sum_c<double, long>};
210210

211-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_FLT][eft_INT] = {eft_INT, (void*)dpnp_sum_c<float, int>};
212-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_FLT][eft_LNG] = {eft_LNG, (void*)dpnp_sum_c<float, long>};
211+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_FLT][eft_INT] = {eft_INT, (void*)dpnp_sum_c<int, float>};
212+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_FLT][eft_LNG] = {eft_LNG, (void*)dpnp_sum_c<long, float>};
213213
fmap[DPNPFuncName::DPNP_FN_SUM][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_sum_c<float, float>};
214-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_sum_c<float, double>};
214+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_sum_c<double, float>};
215215

216-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_sum_c<double, int>};
217-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_sum_c<double, long>};
218-
fmap[DPNPFuncName::DPNP_FN_SUM][eft_DBL][eft_FLT] = {eft_FLT, (void*)dpnp_sum_c<double, float>};
216+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_sum_c<int, double>};
217+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_sum_c<long, double>};
218+
fmap[DPNPFuncName::DPNP_FN_SUM][eft_DBL][eft_FLT] = {eft_FLT, (void*)dpnp_sum_c<float, double>};
219219
fmap[DPNPFuncName::DPNP_FN_SUM][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_sum_c<double, double>};
220220

221221
return;

dpnp/backend/kernels/dpnp_krnl_statistics.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ void dpnp_mean_c(void* array1_in, void* result1, const size_t* shape, size_t ndi
195195
{
196196
_ResultType* sum = reinterpret_cast<_ResultType*>(dpnp_memory_alloc_c(1 * sizeof(_ResultType)));
197197

198-
dpnp_sum_c<_DataType, _ResultType>(
198+
dpnp_sum_c<_ResultType, _DataType>(
199199
sum, array1_in, shape, ndim, reinterpret_cast<const long*>(axis), naxis, nullptr, nullptr);
200200

201201
result[0] = sum[0] / static_cast<_ResultType>(size);

0 commit comments

Comments
 (0)