Skip to content

Commit 25b75e3

Browse files
authored
Native triu (#627)
* native triu
1 parent 6c26d3d commit 25b75e3

File tree

7 files changed

+182
-23
lines changed

7 files changed

+182
-23
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,21 @@ template <typename _DataType>
582582
INP_DLLEXPORT void dpnp_tril_c(
583583
void* array, void* result, const int k, size_t* shape, size_t* res_shape, const size_t ndim, const size_t res_ndim);
584584

585+
/**
586+
* @ingroup BACKEND_API
587+
* @brief math library implementation of take function
588+
*
589+
* @param [in] array Input array with data.
590+
* @param [out] result Output array.
591+
* @param [in] k Diagonal above which to zero elements.
592+
* @param [in] shape Shape of input array.
593+
* @param [in] res_shape Shape of result array.
594+
* @param [in] ndim Number of elements in array.shape.
595+
* @param [in] res_ndim Number of elements in res_shape.
596+
*/
597+
template <typename _DataType>
598+
INP_DLLEXPORT void dpnp_triu_c(void* array, void* result, const int k, size_t* shape, size_t* res_shape, const size_t ndim, const size_t res_ndim);
599+
585600
/**
586601
* @ingroup BACKEND_API
587602
* @brief math library implementation of var function

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ enum class DPNPFuncName : size_t
191191
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() implementation */
192192
DPNP_FN_TRAPZ, /**< Used in numpy.trapz() implementation */
193193
DPNP_FN_TRIL, /**< Used in numpy.tril() implementation */
194+
DPNP_FN_TRIU, /**< Used in numpy.triu() implementation */
194195
DPNP_FN_TRUNC, /**< Used in numpy.trunc() implementation */
195196
DPNP_FN_VAR, /**< Used in numpy.var() implementation */
196197
DPNP_FN_ZEROS, /**< Used in numpy.zeros() implementation */

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,35 @@ void dpnp_tril_c(void* array_in,
135135
const size_t ndim,
136136
const size_t res_ndim)
137137
{
138+
if ((array_in == nullptr) || (result1 == nullptr))
139+
{
140+
return;
141+
}
142+
138143
_DataType* array_m = reinterpret_cast<_DataType*>(array_in);
139144
_DataType* result = reinterpret_cast<_DataType*>(result1);
140145

146+
if ((shape == nullptr) || (res_shape == nullptr))
147+
{
148+
return;
149+
}
150+
151+
if ((ndim == 0) || (res_ndim == 0))
152+
{
153+
return;
154+
}
155+
141156
size_t res_size = 1;
142157
for (size_t i = 0; i < res_ndim; ++i)
143158
{
144159
res_size *= res_shape[i];
145160
}
146161

162+
if (res_size == 0)
163+
{
164+
return;
165+
}
166+
147167
if (ndim == 1)
148168
{
149169
for (size_t i = 0; i < res_size; ++i)
@@ -211,6 +231,104 @@ void dpnp_tril_c(void* array_in,
211231
return;
212232
}
213233

234+
template <typename _DataType>
235+
void dpnp_triu_c(void* array_in, void* result1, const int k, size_t* shape, size_t* res_shape, const size_t ndim, const size_t res_ndim)
236+
{
237+
if ((array_in == nullptr) || (result1 == nullptr))
238+
{
239+
return;
240+
}
241+
_DataType* array_m = reinterpret_cast<_DataType*>(array_in);
242+
_DataType* result = reinterpret_cast<_DataType*>(result1);
243+
244+
if ((shape == nullptr) || (res_shape == nullptr))
245+
{
246+
return;
247+
}
248+
249+
if ((ndim == 0) || (res_ndim == 0))
250+
{
251+
return;
252+
}
253+
254+
size_t res_size = 1;
255+
for (size_t i = 0; i < res_ndim; ++i)
256+
{
257+
res_size *= res_shape[i];
258+
}
259+
260+
if (res_size == 0)
261+
{
262+
return;
263+
}
264+
265+
if (ndim == 1)
266+
{
267+
for (size_t i = 0; i < res_size; ++i)
268+
{
269+
size_t n = res_size;
270+
size_t val = i;
271+
int ids[res_ndim];
272+
for (size_t j = 0; j < res_ndim; ++j)
273+
{
274+
n /= res_shape[j];
275+
size_t p = val / n;
276+
ids[j] = p;
277+
if (p != 0)
278+
{
279+
val = val - p * n;
280+
}
281+
}
282+
283+
int diag_idx_ = (ids[res_ndim - 2] + k > -1) ? (ids[res_ndim - 2] + k) : -1;
284+
int values = res_shape[res_ndim - 1];
285+
int diag_idx = (values < diag_idx_) ? values : diag_idx_;
286+
287+
if (ids[res_ndim - 1] >= diag_idx)
288+
{
289+
result[i] = array_m[ids[res_ndim - 1]];
290+
}
291+
else
292+
{
293+
result[i] = 0;
294+
}
295+
}
296+
}
297+
else
298+
{
299+
for (size_t i = 0; i < res_size; ++i)
300+
{
301+
size_t n = res_size;
302+
size_t val = i;
303+
int ids[res_ndim];
304+
for (size_t j = 0; j < res_ndim; ++j)
305+
{
306+
n /= res_shape[j];
307+
size_t p = val / n;
308+
ids[j] = p;
309+
if (p != 0)
310+
{
311+
val = val - p * n;
312+
}
313+
}
314+
315+
int diag_idx_ = (ids[res_ndim - 2] + k > -1) ? (ids[res_ndim - 2] + k) : -1;
316+
int values = res_shape[res_ndim - 1];
317+
int diag_idx = (values < diag_idx_) ? values : diag_idx_;
318+
319+
if (ids[res_ndim - 1] >= diag_idx)
320+
{
321+
result[i] = array_m[i];
322+
}
323+
else
324+
{
325+
result[i] = 0;
326+
}
327+
}
328+
}
329+
return;
330+
}
331+
214332
void func_map_init_arraycreation(func_map_t& fmap)
215333
{
216334
fmap[DPNPFuncName::DPNP_FN_ARANGE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_arange_c<int>};
@@ -247,5 +365,10 @@ void func_map_init_arraycreation(func_map_t& fmap)
247365
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_tril_c<float>};
248366
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_tril_c<double>};
249367

368+
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_triu_c<int>};
369+
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_triu_c<long>};
370+
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_triu_c<float>};
371+
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_triu_c<double>};
372+
250373
return;
251374
}

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
164164
DPNP_FN_TRANSPOSE
165165
DPNP_FN_TRAPZ
166166
DPNP_FN_TRIL
167+
DPNP_FN_TRIU
167168
DPNP_FN_TRUNC
168169
DPNP_FN_VAR
169170
DPNP_FN_ZEROS

dpnp/dpnp_algo/dpnp_algo_arraycreation.pyx

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -263,35 +263,21 @@ cpdef dparray dpnp_tril(dparray m, int k):
263263
return result
264264

265265

266-
cpdef dparray dpnp_triu(m, k):
267-
cdef dparray result
266+
cpdef dparray dpnp_triu(dparray m, int k):
268267
if m.ndim == 1:
268+
res_shape=(m.shape[0], m.shape[0])
269+
else:
270+
res_shape=m.shape
269271

270-
result = dparray(shape=(m.shape[0], m.shape[0]), dtype=m.dtype)
271-
272-
for i in range(result.size):
273-
ids = get_axis_indeces(i, result.shape)
274-
275-
diag_idx = max(-1, ids[result.ndim - 2] + k)
276-
diag_idx = min(diag_idx, result.shape[result.ndim - 1])
272+
cdef dparray result = dparray(shape=res_shape, dtype=m.dtype)
277273

278-
if ids[result.ndim - 1] >= diag_idx:
279-
result[i] = m[ids[result.ndim - 1]]
280-
else:
281-
result[i] = 0
282-
else:
283-
result = dparray(shape=m.shape, dtype=m.dtype)
274+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(m.dtype)
284275

285-
for i in range(result.size):
286-
ids = get_axis_indeces(i, result.shape)
276+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRIU, param1_type, param1_type)
287277

288-
diag_idx = max(-1, ids[result.ndim - 2] + k)
289-
diag_idx = min(diag_idx, result.shape[result.ndim - 1])
278+
cdef custom_1in_1out_func_ptr_t func = <custom_1in_1out_func_ptr_t > kernel_data.ptr
290279

291-
if ids[result.ndim - 1] >= diag_idx:
292-
result[i] = m[i]
293-
else:
294-
result[i] = 0
280+
func(m.get_data(), result.get_data(), k, < size_t * > m._dparray_shape.data(), < size_t * > result._dparray_shape.data(), m.ndim, result.ndim)
295281

296282
return result
297283

dpnp/dpnp_iface_arraycreation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,8 @@ def triu(m, k=0):
11141114
if not use_origin_backend(m):
11151115
if not isinstance(m, dparray):
11161116
pass
1117+
elif not isinstance(k, int):
1118+
pass
11171119
else:
11181120
return dpnp_triu(m, k)
11191121

tests/test_arraycreation.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,34 @@ def test_tril(m, k):
162162
expected = numpy.tril(a, k)
163163
result = dpnp.tril(ia, k)
164164
numpy.testing.assert_array_equal(expected, result)
165+
166+
167+
@pytest.mark.parametrize("k",
168+
[-4, -3, -2, -1, 0, 1, 2, 3, 4],
169+
ids=['-4', '-3', '-2', '-1', '0', '1', '2', '3', '4'])
170+
@pytest.mark.parametrize("m",
171+
[[0, 1, 2, 3, 4],
172+
[[1, 2], [3, 4]],
173+
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
174+
[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]],
175+
ids=['[0, 1, 2, 3, 4]',
176+
'[[1, 2], [3, 4]]',
177+
'[[0, 1, 2], [3, 4, 5], [6, 7, 8]]',
178+
'[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]'])
179+
def test_triu(m, k):
180+
a = numpy.array(m)
181+
ia = dpnp.array(a)
182+
expected = numpy.triu(a, k)
183+
result = dpnp.triu(ia, k)
184+
numpy.testing.assert_array_equal(expected, result)
185+
186+
187+
@pytest.mark.parametrize("k",
188+
[-4, -3, -2, -1, 0, 1, 2, 3, 4],
189+
ids=['-4', '-3', '-2', '-1', '0', '1', '2', '3', '4'])
190+
def test_triu_size_null(k):
191+
a = numpy.ones(shape=(1, 2, 0))
192+
ia = dpnp.array(a)
193+
expected = numpy.triu(a, k)
194+
result = dpnp.triu(ia, k)
195+
numpy.testing.assert_array_equal(expected, result)

0 commit comments

Comments
 (0)