Skip to content

Commit f6dd61a

Browse files
authored
add broadcasting for remainder func (#714)
* add broadcasting for remainder
1 parent c79ad0c commit f6dd61a

File tree

3 files changed

+96
-38
lines changed

3 files changed

+96
-38
lines changed

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 62 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -346,50 +346,80 @@ void dpnp_remainder_c(void* result_out,
346346
const size_t input2_shape_ndim,
347347
const size_t* where)
348348
{
349-
(void)input1_shape;
350-
(void)input1_shape_ndim;
351-
(void)input2_size;
352-
(void)input2_shape;
353-
(void)input2_shape_ndim;
354349
(void)where;
355350

356-
cl::sycl::event event;
357-
_DataType_input1* input1 = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in));
358-
_DataType_input2* input2 = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in));
351+
if (!input1_size || !input2_size)
352+
{
353+
return;
354+
}
355+
356+
_DataType_input1* input1_data = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in));
357+
_DataType_input2* input2_data = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in));
359358
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out);
360359

361-
if constexpr ((std::is_same<_DataType_input1, double>::value || std::is_same<_DataType_input1, float>::value) &&
362-
std::is_same<_DataType_input2, _DataType_input1>::value)
360+
std::vector<size_t> result_shape = get_result_shape(input1_shape, input1_shape_ndim,
361+
input2_shape, input2_shape_ndim);
362+
363+
DPNPC_id<_DataType_input1>* input1_it;
364+
const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>);
365+
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(input1_it_size_in_bytes));
366+
new (input1_it) DPNPC_id<_DataType_input1>(input1_data, input1_shape, input1_shape_ndim);
367+
368+
input1_it->broadcast_to_shape(result_shape);
369+
370+
DPNPC_id<_DataType_input2>* input2_it;
371+
const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>);
372+
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(input2_it_size_in_bytes));
373+
new (input2_it) DPNPC_id<_DataType_input2>(input2_data, input2_shape, input2_shape_ndim);
374+
375+
input2_it->broadcast_to_shape(result_shape);
376+
377+
const size_t result_size = input1_it->get_output_size();
378+
379+
cl::sycl::range<1> gws(result_size);
380+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
381+
const size_t i = global_id[0];
382+
const _DataType_output input1_elem = (*input1_it)[i];
383+
const _DataType_output input2_elem = (*input2_it)[i];
384+
double fmod_res = cl::sycl::fmod((double)input1_elem, (double)input2_elem);
385+
double add = fmod_res + input2_elem;
386+
result[i] = cl::sycl::fmod(add, (double)input2_elem);
387+
388+
};
389+
auto kernel_func = [&](cl::sycl::handler& cgh) {
390+
cgh.parallel_for<class dpnp_remainder_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
391+
gws, kernel_parallel_for_func);
392+
};
393+
394+
cl::sycl::event event;
395+
396+
if (input1_size == input2_size)
363397
{
364-
event = oneapi::mkl::vm::fmod(DPNP_QUEUE, input1_size, input1, input2, result);
365-
event.wait();
366-
event = oneapi::mkl::vm::add(DPNP_QUEUE, input1_size, result, input2, result);
367-
event.wait();
368-
event = oneapi::mkl::vm::fmod(DPNP_QUEUE, input1_size, result, input2, result);
398+
if constexpr ((std::is_same<_DataType_input1, double>::value ||
399+
std::is_same<_DataType_input1, float>::value) &&
400+
std::is_same<_DataType_input2, _DataType_input1>::value)
401+
{
402+
event = oneapi::mkl::vm::fmod(DPNP_QUEUE, input1_size, input1_data, input2_data, result);
403+
event.wait();
404+
event = oneapi::mkl::vm::add(DPNP_QUEUE, input1_size, result, input2_data, result);
405+
event.wait();
406+
event = oneapi::mkl::vm::fmod(DPNP_QUEUE, input1_size, result, input2_data, result);
407+
}
408+
else
409+
{
410+
event = DPNP_QUEUE.submit(kernel_func);
411+
}
369412
}
370413
else
371414
{
372-
cl::sycl::range<1> gws(input1_size);
373-
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
374-
size_t i = global_id[0]; /*for (size_t i = 0; i < size; ++i)*/
375-
{
376-
_DataType_input1 input_elem1 = input1[i];
377-
_DataType_input2 input_elem2 = input2[i];
378-
double fmod = cl::sycl::fmod((double)input_elem1, (double)input_elem2);
379-
double add = fmod + input_elem2;
380-
result[i] = cl::sycl::fmod(add, (double)input_elem2);
381-
}
382-
};
383-
384-
auto kernel_func = [&](cl::sycl::handler& cgh) {
385-
cgh.parallel_for<class dpnp_remainder_c_kernel<_DataType_input1, _DataType_input2, _DataType_output>>(
386-
gws, kernel_parallel_for_func);
387-
};
388-
389415
event = DPNP_QUEUE.submit(kernel_func);
390416
}
391417

392418
event.wait();
419+
420+
input1_it->~DPNPC_id();
421+
input2_it->~DPNPC_id();
422+
393423
}
394424

395425
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>

dpnp/dpnp_iface_mathematical.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,15 +1385,16 @@ def prod(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, wher
13851385
return call_origin(numpy.prod, x1, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
13861386

13871387

1388-
def remainder(x1, x2, **kwargs):
1388+
def remainder(x1, x2, out=None, where=True, dtype=None, **kwargs):
13891389
"""
13901390
Return element-wise remainder of division.
13911391
13921392
For full documentation refer to :obj:`numpy.remainder`.
13931393
13941394
Limitations
13951395
-----------
1396-
Parameters ``x1`` and ``x2`` are supported as :obj:`dpnp.ndarray`.
1396+
Parameters ``x1`` and ``x2`` are supported as either :obj:`dpnp.ndarray` or scalar.
1397+
Parameters ``dtype``, ``out`` and ``where`` are supported with their default values.
13971398
Keyword arguments ``kwargs`` are currently unsupported.
13981399
Otherwise the functions will be executed sequentially on CPU.
13991400
Input array data types are limited by supported DPNP :ref:`Data types`.
@@ -1414,18 +1415,42 @@ def remainder(x1, x2, **kwargs):
14141415
14151416
"""
14161417

1418+
x1_is_scalar = dpnp.isscalar(x1)
1419+
x2_is_scalar = dpnp.isscalar(x2)
14171420
x1_desc = dpnp.get_dpnp_descriptor(x1)
14181421
x2_desc = dpnp.get_dpnp_descriptor(x2)
14191422

14201423
if x1_desc and x2_desc and not kwargs:
1421-
if x1_desc.size != x2_desc.size:
1424+
if not x1_desc and not x1_is_scalar:
1425+
pass
1426+
elif not x2_desc and not x2_is_scalar:
14221427
pass
1423-
elif x1_desc.shape != x2_desc.shape:
1428+
elif x1_is_scalar and x2_is_scalar:
1429+
pass
1430+
elif x1_desc and x1_desc.ndim == 0:
1431+
pass
1432+
elif x2_desc and x2_desc.ndim == 0:
1433+
pass
1434+
elif x2_is_scalar and not x2_desc:
1435+
pass
1436+
elif x1_desc and x2_desc and x1_desc.size != x2_desc.size:
1437+
pass
1438+
elif x1_desc and x2_desc and x1_desc.shape != x2_desc.shape:
1439+
pass
1440+
elif out is not None and not isinstance(out, dparray):
1441+
pass
1442+
elif dtype is not None:
1443+
pass
1444+
elif out is not None:
1445+
pass
1446+
elif not where:
1447+
pass
1448+
elif x1_is_scalar and x2_desc.ndim > 1:
14241449
pass
14251450
else:
1426-
return dpnp_remainder(x1_desc, x2_desc).get_pyobj()
1451+
return dpnp_remainder(x1_desc, x2_desc, out=out, where=where, dtype=dtype)
14271452

1428-
return call_origin(numpy.remainder, x1, x2, **kwargs)
1453+
return call_origin(numpy.remainder, x1, x2, out=out, where=where, dtype=dtype, **kwargs)
14291454

14301455

14311456
def round_(a, decimals=0, out=None):

tests/test_mathematical.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ def test_minimum(self, dtype, lhs, rhs):
148148
def test_multiply(self, dtype, lhs, rhs):
149149
self._test_mathematical('multiply', dtype, lhs, rhs)
150150

151+
def test_remainder(self, dtype, lhs, rhs):
152+
self._test_mathematical('remainder', dtype, lhs, rhs)
153+
151154
def test_power(self, dtype, lhs, rhs):
152155
self._test_mathematical('power', dtype, lhs, rhs)
153156

0 commit comments

Comments
 (0)