Skip to content

Commit 6326cb7

Browse files
authored
add broadcasting for floor_divide func (#702)
* add broadcasting for floor_divide
1 parent cceb7cf commit 6326cb7

File tree

3 files changed

+100
-37
lines changed

3 files changed

+100
-37
lines changed

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 61 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
#include <dpnp_iface.hpp>
3131
#include "dpnp_fptr.hpp"
32+
#include "dpnp_iterator.hpp"
3233
#include "dpnp_utils.hpp"
3334
#include "queue_sycl.hpp"
3435

@@ -226,47 +227,78 @@ void dpnp_floor_divide_c(void* result_out,
226227
const size_t input2_shape_ndim,
227228
const size_t* where)
228229
{
229-
(void)input1_shape;
230-
(void)input1_shape_ndim;
231-
(void)input2_size;
232-
(void)input2_shape;
233-
(void)input2_shape_ndim;
234230
(void)where;
235231

236-
cl::sycl::event event;
237-
_DataType_input1* input1 = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in));
238-
_DataType_input2* input2 = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in));
232+
if (!input1_size || !input2_size)
233+
{
234+
return;
235+
}
236+
237+
_DataType_input1* input1_data = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in));
238+
_DataType_input2* input2_data = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in));
239239
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out);
240240

241-
if constexpr ((std::is_same<_DataType_input1, double>::value || std::is_same<_DataType_input1, float>::value) &&
242-
std::is_same<_DataType_input2, _DataType_input1>::value)
241+
std::vector<size_t> result_shape = get_result_shape(input1_shape, input1_shape_ndim,
242+
input2_shape, input2_shape_ndim);
243+
244+
DPNPC_id<_DataType_input1>* input1_it;
245+
const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>);
246+
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(input1_it_size_in_bytes));
247+
new (input1_it) DPNPC_id<_DataType_input1>(input1_data, input1_shape, input1_shape_ndim);
248+
249+
input1_it->broadcast_to_shape(result_shape);
250+
251+
DPNPC_id<_DataType_input2>* input2_it;
252+
const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>);
253+
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(input2_it_size_in_bytes));
254+
new (input2_it) DPNPC_id<_DataType_input2>(input2_data, input2_shape, input2_shape_ndim);
255+
256+
input2_it->broadcast_to_shape(result_shape);
257+
258+
const size_t result_size = input1_it->get_output_size();
259+
260+
261+
cl::sycl::range<1> gws(result_size);
262+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
263+
const size_t i = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */
264+
const _DataType_output input1_elem = (*input1_it)[i];
265+
const _DataType_output input2_elem = (*input2_it)[i];
266+
267+
double div = (double)input1_elem / (double)input2_elem;
268+
result[i] = static_cast<_DataType_output>(cl::sycl::floor(div));
269+
};
270+
auto kernel_func = [&](cl::sycl::handler& cgh) {
271+
cgh.parallel_for<class dpnp_floor_divide_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
272+
gws, kernel_parallel_for_func);
273+
};
274+
275+
cl::sycl::event event;
276+
277+
if (input1_size == input2_size)
243278
{
244-
event = oneapi::mkl::vm::div(DPNP_QUEUE, input1_size, input1, input2, result);
245-
event.wait();
246-
event = oneapi::mkl::vm::floor(DPNP_QUEUE, input1_size, result, result);
279+
if constexpr ((std::is_same<_DataType_input1, double>::value ||
280+
std::is_same<_DataType_input1, float>::value) &&
281+
std::is_same<_DataType_input2, _DataType_input1>::value)
282+
{
283+
event = oneapi::mkl::vm::div(DPNP_QUEUE, input1_size, input1_data, input2_data, result);
284+
event.wait();
285+
event = oneapi::mkl::vm::floor(DPNP_QUEUE, input1_size, result, result);
286+
}
287+
else
288+
{
289+
event = DPNP_QUEUE.submit(kernel_func);
290+
}
247291
}
248292
else
249293
{
250-
cl::sycl::range<1> gws(input1_size);
251-
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
252-
size_t i = global_id[0]; /*for (size_t i = 0; i < size; ++i)*/
253-
{
254-
_DataType_input1 input_elem1 = input1[i];
255-
_DataType_input2 input_elem2 = input2[i];
256-
double div = (double)input_elem1 / (double)input_elem2;
257-
result[i] = static_cast<_DataType_output>(cl::sycl::floor(div));
258-
}
259-
};
260-
261-
auto kernel_func = [&](cl::sycl::handler& cgh) {
262-
cgh.parallel_for<class dpnp_floor_divide_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
263-
gws, kernel_parallel_for_func);
264-
};
265-
266294
event = DPNP_QUEUE.submit(kernel_func);
267295
}
268296

269297
event.wait();
298+
299+
input1_it->~DPNPC_id();
300+
input2_it->~DPNPC_id();
301+
270302
}
271303

272304
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2>

dpnp/dpnp_iface_mathematical.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -681,15 +681,16 @@ def floor(x1, **kwargs):
681681
return call_origin(numpy.floor, x1, **kwargs)
682682

683683

684-
def floor_divide(x1, x2, **kwargs):
684+
def floor_divide(x1, x2, dtype=None, out=None, where=True, **kwargs):
685685
"""
686686
Compute the largest integer smaller than or equal to the division of the inputs.
687687
688688
For full documentation refer to :obj:`numpy.floor_divide`.
689689
690690
Limitations
691691
-----------
692-
Parameters ``x1`` and ``x2`` are supported as :obj:`dpnp.ndarray`.
692+
Parameters ``x1`` and ``x2`` are supported as either :obj:`dpnp.ndarray` or a scalar.
693+
Parameters ``dtype``, ``out`` and ``where`` are supported only with their default values.
693694
Keyword arguments ``kwargs`` are currently unsupported.
694695
Otherwise the function will be executed sequentially on CPU.
695696
Input array data types are limited by supported DPNP :ref:`Data types`.
@@ -710,13 +711,40 @@ def floor_divide(x1, x2, **kwargs):
710711
711712
"""
712713

713-
is_x1_dparray = isinstance(x1, dparray)
714-
is_x2_dparray = isinstance(x2, dparray)
715-
716-
if not use_origin_backend(x1) and is_x1_dparray and is_x2_dparray and not kwargs:
717-
return dpnp_floor_divide(x1, x2)
714+
x1_is_scalar, x2_is_scalar = dpnp.isscalar(x1), dpnp.isscalar(x2)
715+
x1_is_dparray, x2_is_dparray = isinstance(x1, dparray), isinstance(x2, dparray)
716+
717+
if not use_origin_backend(x1) and not kwargs:
718+
if not x1_is_dparray and not x1_is_scalar:
719+
pass
720+
elif not x2_is_dparray and not x2_is_scalar:
721+
pass
722+
elif x1_is_scalar and x2_is_scalar:
723+
pass
724+
elif x1_is_dparray and x1.ndim == 0:
725+
pass
726+
elif x2_is_dparray and x2.ndim == 0:
727+
pass
728+
elif x2_is_scalar and x2 == 0:
729+
pass
730+
elif x1_is_dparray and x2_is_dparray and x1.size != x2.size:
731+
pass
732+
elif x1_is_dparray and x2_is_dparray and x1.shape != x2.shape:
733+
pass
734+
elif out is not None and not isinstance(out, dparray):
735+
pass
736+
elif dtype is not None:
737+
pass
738+
elif out is not None:
739+
pass
740+
elif not where:
741+
pass
742+
elif x1_is_scalar and x2.ndim > 1:
743+
pass
744+
else:
745+
return dpnp_floor_divide(x1, x2, out=out, where=where, dtype=dtype)
718746

719-
return call_origin(numpy.floor_divide, x1, x2, **kwargs)
747+
return call_origin(numpy.floor_divide, x1, x2, out=out, where=where, dtype=dtype, **kwargs)
720748

721749

722750
def fmax(*args, **kwargs):

tests/test_mathematical.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def test_divide(self, dtype, lhs, rhs):
105105
def test_fmod(self, dtype, lhs, rhs):
106106
self._test_mathematical('fmod', dtype, lhs, rhs)
107107

108+
def test_floor_divide(self, dtype, lhs, rhs):
109+
self._test_mathematical('floor_divide', dtype, lhs, rhs)
110+
108111
def test_hypot(self, dtype, lhs, rhs):
109112
self._test_mathematical('hypot', dtype, lhs, rhs)
110113

0 commit comments

Comments
 (0)