Skip to content

Commit 211404a

Browse files
Take func kernel (#579)
* take func kernel impl
1 parent c366fb7 commit 211404a

File tree

6 files changed

+35
-32
lines changed

6 files changed

+35
-32
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -457,11 +457,11 @@ INP_DLLEXPORT void dpnp_std_c(
457457
* @brief math library implementation of take function
458458
*
459459
* @param [in] array Input array with data.
460-
* @param [in] array Input array with indices.
461-
* @param [out] result Output array with indeces.
460+
* @param [in] indices Input array with indices.
461+
* @param [out] result Output array.
462462
* @param [in] size Number of elements in the input array.
463463
*/
464-
template <typename _DataType>
464+
template <typename _DataType, typename _IndecesType>
465465
INP_DLLEXPORT void dpnp_take_c(void* array, void* indices, void* result, size_t size);
466466

467467
/**

dpnp/backend/kernels/dpnp_krnl_indexing.cpp

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,39 +23,45 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525

26-
#include <iostream>
27-
#include <list>
28-
2926
#include <dpnp_iface.hpp>
3027
#include "dpnp_fptr.hpp"
31-
#include "dpnp_utils.hpp"
3228
#include "queue_sycl.hpp"
3329

34-
template <typename _DataType>
30+
template <typename _DataType, typename _IndecesType>
3531
class dpnp_take_c_kernel;
3632

37-
template <typename _DataType>
33+
template <typename _DataType, typename _IndecesType>
3834
void dpnp_take_c(void* array1_in, void* indices1, void* result1, size_t size)
3935
{
4036
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
4137
_DataType* result = reinterpret_cast<_DataType*>(result1);
42-
size_t* indices = reinterpret_cast<size_t*>(indices1);
38+
_IndecesType* indices = reinterpret_cast<_IndecesType*>(indices1);
39+
40+
cl::sycl::range<1> gws(size);
41+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
42+
const size_t idx = global_id[0];
43+
result[idx] = array_1[indices[idx]];
44+
};
45+
46+
auto kernel_func = [&](cl::sycl::handler& cgh) {
47+
cgh.parallel_for<class dpnp_take_c_kernel<_DataType, _IndecesType>>(gws, kernel_parallel_for_func);
48+
};
49+
50+
cl::sycl::event event = DPNP_QUEUE.submit(kernel_func);
4351

44-
for (size_t i = 0; i < size; i++)
45-
{
46-
size_t ind = indices[i];
47-
result[i] = array_1[ind];
48-
}
52+
event.wait();
4953

5054
return;
5155
}
5256

5357
void func_map_init_indexing_func(func_map_t& fmap)
5458
{
55-
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_take_c<int>};
56-
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_take_c<long>};
57-
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_take_c<float>};
58-
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_take_c<double>};
59+
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_BOOL][eft_BOOL] = {eft_BOOL, (void*)dpnp_take_c<bool, long>};
60+
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_take_c<int, long>};
61+
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_take_c<long, long>};
62+
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_take_c<float, long>};
63+
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_take_c<double, long>};
64+
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_take_c<std::complex<double>, long>};
5965

6066
return;
6167
}

dpnp/dparray.pyx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,13 @@ cdef class dparray:
863863
"""
864864
return choose(input, choices, out, mode)
865865
866+
def take(self, indices, axis=None, out=None, mode='raise'):
867+
"""
868+
Take elements from an array.
869+
For full documentation refer to :obj:`numpy.take`.
870+
"""
871+
return take(self, indices, axis, out, mode)
872+
866873
"""
867874
-------------------------------------------------------------------------
868875
Sorting

dpnp/dpnp_algo/dpnp_algo_indexing.pyx

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,20 +261,18 @@ cpdef dparray dpnp_select(condlist, choicelist, default):
261261

262262

263263
cpdef dparray dpnp_take(dparray input, dparray indices):
264-
indices_size = indices.size
265-
266264
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
267265

268266
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TAKE, param1_type, param1_type)
269267

270268
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
271-
cdef dparray result = dparray(indices_size, dtype=result_type)
269+
cdef dparray result = dparray(indices.shape, dtype=result_type)
272270

273271
cdef custom_indexing_2in_1out_func_ptr_t func = <custom_indexing_2in_1out_func_ptr_t > kernel_data.ptr
274272

275-
func(input.get_data(), indices.get_data(), result.get_data(), indices_size)
273+
func(input.get_data(), indices.get_data(), result.get_data(), indices.size)
276274

277-
return result.reshape(indices.shape)
275+
return result
278276

279277

280278
cpdef tuple dpnp_tril_indices(n, k=0, m=None):

tests/skipped_tests.tbl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diago
569569
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative4
570570
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative5
571571
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_external_diagonal
572-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_external_take_by_scalar
573572
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract
574573
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_empty_1dim
575574
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_no_bool
@@ -578,10 +577,7 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extra
578577
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_size_mismatch2
579578
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_along_axis
580579
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_along_axis_none_axis
581-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_by_array
582-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_by_scalar
583580
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow
584-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_no_axis
585581
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
586582
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
587583
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diago
908908
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative4
909909
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative5
910910
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_external_diagonal
911-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_external_take_by_scalar
912911
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract
913912
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_empty_1dim
914913
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_no_bool
@@ -917,10 +916,7 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extra
917916
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_size_mismatch2
918917
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_along_axis
919918
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_along_axis_none_axis
920-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_by_array
921-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_by_scalar
922919
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow
923-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_no_axis
924920
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
925921
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
926922
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast

0 commit comments

Comments
 (0)