Skip to content

Commit 3e9c486

Browse files
authored
add native version for function tri (#623)
* add native version for function tri
1 parent 4a347d2 commit 3e9c486

File tree

7 files changed

+142
-35
lines changed

7 files changed

+142
-35
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,18 @@ INP_DLLEXPORT void dpnp_std_c(
630630
template <typename _DataType, typename _IndecesType>
631631
INP_DLLEXPORT void dpnp_take_c(void* array, void* indices, void* result, size_t size);
632632

633+
/**
634+
* @ingroup BACKEND_API
635+
* @brief math library implementation of tri function
636+
*
637+
* @param [out] result Output array.
638+
* @param [in] N Number of rows in the array.
639+
* @param [in] M Number of columns in the array.
640+
* @param [in] k The sub-diagonal at and below which the array is filled.
641+
*/
642+
template <typename _DataType>
643+
INP_DLLEXPORT void dpnp_tri_c(void* result, const size_t N, const size_t M, const int k);
644+
633645
/**
634646
* @ingroup BACKEND_API
635647
* @brief math library implementation of take function

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ enum class DPNPFuncName : size_t
195195
DPNP_FN_TANH, /**< Used in numpy.tanh() implementation */
196196
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() implementation */
197197
DPNP_FN_TRAPZ, /**< Used in numpy.trapz() implementation */
198+
DPNP_FN_TRI, /**< Used in numpy.tri() implementation */
198199
DPNP_FN_TRIL, /**< Used in numpy.tril() implementation */
199200
DPNP_FN_TRIU, /**< Used in numpy.triu() implementation */
200201
DPNP_FN_TRUNC, /**< Used in numpy.trunc() implementation */

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -128,20 +128,48 @@ void dpnp_ones_like_c(void* result, size_t size)
128128
}
129129

130130
template <typename _DataType>
131-
void dpnp_zeros_c(void* result, size_t size)
131+
class dpnp_tri_c_kernel;
132+
133+
template <typename _DataType>
134+
void dpnp_tri_c(void* result1, const size_t N, const size_t M, const int k)
132135
{
133-
_DataType* fill_value = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(sizeof(_DataType)));
134-
fill_value[0] = 0;
136+
cl::sycl::event event;
135137

136-
dpnp_initval_c<_DataType>(result, fill_value, size);
138+
if (!result1 || !N || !M)
139+
{
140+
return;
141+
}
137142

138-
dpnp_memory_free_c(fill_value);
139-
}
143+
_DataType* result = reinterpret_cast<_DataType*>(result1);
140144

141-
template <typename _DataType>
142-
void dpnp_zeros_like_c(void* result, size_t size)
143-
{
144-
dpnp_zeros_c<_DataType>(result, size);
145+
size_t idx = N* M;
146+
cl::sycl::range<1> gws(idx);
147+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
148+
size_t ind = global_id[0];
149+
size_t i = ind / M;
150+
size_t j = ind % M;
151+
152+
int val = i + k + 1;
153+
size_t diag_idx_ = (val > 0) ? (size_t)val : 0;
154+
size_t diag_idx = (M < diag_idx_) ? M : diag_idx_;
155+
156+
if (j < diag_idx)
157+
{
158+
result[ind] = 1;
159+
}
160+
else
161+
{
162+
result[ind] = 0;
163+
}
164+
};
165+
166+
auto kernel_func = [&](cl::sycl::handler& cgh) {
167+
cgh.parallel_for<class dpnp_tri_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
168+
};
169+
170+
event = DPNP_QUEUE.submit(kernel_func);
171+
172+
event.wait();
145173
}
146174

147175
template <typename _DataType>
@@ -353,6 +381,23 @@ void dpnp_triu_c(void* array_in,
353381
return;
354382
}
355383

384+
template <typename _DataType>
385+
void dpnp_zeros_c(void* result, size_t size)
386+
{
387+
_DataType* fill_value = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(sizeof(_DataType)));
388+
fill_value[0] = 0;
389+
390+
dpnp_initval_c<_DataType>(result, fill_value, size);
391+
392+
dpnp_memory_free_c(fill_value);
393+
}
394+
395+
template <typename _DataType>
396+
void dpnp_zeros_like_c(void* result, size_t size)
397+
{
398+
dpnp_zeros_c<_DataType>(result, size);
399+
}
400+
356401
void func_map_init_arraycreation(func_map_t& fmap)
357402
{
358403
fmap[DPNPFuncName::DPNP_FN_ARANGE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_arange_c<int>};
@@ -395,6 +440,21 @@ void func_map_init_arraycreation(func_map_t& fmap)
395440
fmap[DPNPFuncName::DPNP_FN_ONES_LIKE][eft_C128][eft_C128] = {eft_C128,
396441
(void*)dpnp_ones_like_c<std::complex<double>>};
397442

443+
fmap[DPNPFuncName::DPNP_FN_TRI][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_tri_c<int>};
444+
fmap[DPNPFuncName::DPNP_FN_TRI][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_tri_c<long>};
445+
fmap[DPNPFuncName::DPNP_FN_TRI][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_tri_c<float>};
446+
fmap[DPNPFuncName::DPNP_FN_TRI][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_tri_c<double>};
447+
448+
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_tril_c<int>};
449+
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_tril_c<long>};
450+
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_tril_c<float>};
451+
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_tril_c<double>};
452+
453+
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_triu_c<int>};
454+
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_triu_c<long>};
455+
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_triu_c<float>};
456+
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_triu_c<double>};
457+
398458
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_zeros_c<int>};
399459
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_zeros_c<long>};
400460
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_zeros_c<float>};
@@ -410,15 +470,5 @@ void func_map_init_arraycreation(func_map_t& fmap)
410470
fmap[DPNPFuncName::DPNP_FN_ZEROS_LIKE][eft_C128][eft_C128] = {eft_C128,
411471
(void*)dpnp_ones_like_c<std::complex<double>>};
412472

413-
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_tril_c<int>};
414-
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_tril_c<long>};
415-
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_tril_c<float>};
416-
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_tril_c<double>};
417-
418-
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_triu_c<int>};
419-
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_triu_c<long>};
420-
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_triu_c<float>};
421-
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_triu_c<double>};
422-
423473
return;
424474
}

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
168168
DPNP_FN_TANH
169169
DPNP_FN_TRANSPOSE
170170
DPNP_FN_TRAPZ
171+
DPNP_FN_TRI
171172
DPNP_FN_TRIL
172173
DPNP_FN_TRIU
173174
DPNP_FN_TRUNC

dpnp/dpnp_algo/dpnp_algo_arraycreation.pyx

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ __all__ += [
5959
]
6060

6161

62-
ctypedef void(*custom_1in_1out_func_ptr_t)(void *, void * , const int , size_t * , size_t * , const size_t, const size_t)
62+
ctypedef void(*custom_1in_1out_func_ptr_t)(void * , void * , const int , size_t * , size_t * , const size_t, const size_t)
63+
ctypedef void(*custom_indexing_1out_func_ptr_t)(void * , const size_t , const size_t , const int)
6364

6465

6566
cpdef dparray dpnp_copy(dparray x1, order, subok):
@@ -251,21 +252,26 @@ cpdef dparray dpnp_ones_like(result_shape, result_dtype):
251252
return call_fptr_1out(DPNP_FN_ONES_LIKE, result_shape, result_dtype)
252253

253254

254-
cpdef dparray dpnp_tri(N, M, k, dtype):
255-
cdef dparray result
256-
255+
cpdef dparray dpnp_tri(N, M=None, k=0, dtype=numpy.float):
257256
if M is None:
258257
M = N
259258

260-
result = dparray(shape=(N, M), dtype=dtype)
259+
if dtype == numpy.float:
260+
dtype = numpy.float64
261+
262+
cdef dparray result
263+
264+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(dtype)
265+
266+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRI, param1_type, param1_type)
267+
268+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
269+
270+
result = dparray(shape=(N, M), dtype=result_type)
271+
272+
cdef custom_indexing_1out_func_ptr_t func = <custom_indexing_1out_func_ptr_t > kernel_data.ptr
261273

262-
for i in range(N):
263-
diag_idx = max(0, i + k + 1)
264-
diag_idx = min(diag_idx, M)
265-
for j in range(diag_idx):
266-
result[i, j] = 1
267-
for j in range(diag_idx, M):
268-
result[i, j] = 0
274+
func(result.get_data(), N, M, k)
269275

270276
return result
271277

dpnp/dpnp_iface_arraycreation.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,7 @@ def ones_like(x1, dtype=None, order='C', subok=False, shape=None):
10361036
return numpy.ones_like(x1, dtype, order, subok, shape)
10371037

10381038

1039-
def tri(N, M=None, k=0, dtype=numpy.float):
1039+
def tri(N, M=None, k=0, dtype=numpy.float, **kwargs):
10401040
"""
10411041
An array with ones at and below the given diagonal and zeros elsewhere.
10421042
@@ -1058,9 +1058,22 @@ def tri(N, M=None, k=0, dtype=numpy.float):
10581058
"""
10591059

10601060
if not use_origin_backend():
1061-
return dpnp_tri(N, M, k, dtype)
1061+
if len(kwargs) != 0:
1062+
pass
1063+
elif not isinstance(N, int):
1064+
pass
1065+
elif N < 0:
1066+
pass
1067+
elif M is not None and not isinstance(M, int):
1068+
pass
1069+
elif M is not None and M < 0:
1070+
pass
1071+
elif not isinstance(k, int):
1072+
pass
1073+
else:
1074+
return dpnp_tri(N, M, k, dtype)
10621075

1063-
return call_origin(numpy.tri, N, M, k, dtype)
1076+
return call_origin(numpy.tri, N, M, k, dtype, **kwargs)
10641077

10651078

10661079
def tril(m, k=0):

tests/test_arraycreation.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,30 @@ def test_loadtxt(type):
138138
numpy.testing.assert_array_equal(dpnp_res, np_res)
139139

140140

141+
@pytest.mark.parametrize("N",
142+
[0, 1, 2, 3, 4],
143+
ids=['0', '1', '2', '3', '4'])
144+
@pytest.mark.parametrize("M",
145+
[0, 1, 2, 3, 4],
146+
ids=['0', '1', '2', '3', '4'])
147+
@pytest.mark.parametrize("k",
148+
[-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
149+
ids=['-5', '-4', '-3', '-2', '-1', '0', '1', '2', '3', '4', '5'])
150+
@pytest.mark.parametrize("type",
151+
[numpy.float64, numpy.float32, float, numpy.int64, numpy.int32, numpy.int, numpy.float, int],
152+
ids=['float64', 'float32', 'numpy.float', 'float', 'int64', 'int32', 'numpy.int', 'int'])
153+
def test_tri(N, M, k, type):
154+
expected = numpy.tri(N, M, k, dtype=type)
155+
result = dpnp.tri(N, M, k, dtype=type)
156+
numpy.testing.assert_array_equal(result, expected)
157+
158+
159+
def test_tri_default_dtype():
160+
expected = numpy.tri(3, 5, -1)
161+
result = dpnp.tri(3, 5, -1)
162+
numpy.testing.assert_array_equal(result, expected)
163+
164+
141165
@pytest.mark.parametrize("k",
142166
[-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6],
143167
ids=['-6', '-5', '-4', '-3', '-2', '-1', '0', '1', '2', '3', '4', '5', '6'])

0 commit comments

Comments
 (0)