Skip to content

Commit 4bc074b

Browse files
authored
Fix cholesky (#538)
* fix linalg.cholesky
1 parent 13497ef commit 4bc074b

File tree

5 files changed

+93
-45
lines changed

5 files changed

+93
-45
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,13 @@ INP_DLLEXPORT void dpnp_sort_c(void* array, void* result, size_t size);
240240
* @ingroup BACKEND_API
241241
* @brief math library implementation of cholesky function
242242
*
243-
* @param [in] array Input array with data.
244-
* @param [out] result Output array.
245-
* @param [in] shape Shape of input array.
243+
* @param [in] array Input array with data.
244+
* @param [out] result Output array.
245+
* @param [in] size Total number of elements in the input array.
246+
* @param [in] data_size Length of the last axis of the input shape; each matrix is data_size x data_size.
246247
*/
247248
template <typename _DataType>
248-
INP_DLLEXPORT void dpnp_cholesky_c(void* array1_in, void* result1, size_t* shape);
249+
INP_DLLEXPORT void dpnp_cholesky_c(void* array1_in, void* result1, const size_t size, const size_t data_size);
249250

250251
/**
251252
* @ingroup BACKEND_API

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,47 +34,78 @@
3434
namespace mkl_blas = oneapi::mkl::blas::row_major;
3535
namespace mkl_lapack = oneapi::mkl::lapack;
3636

37-
template <typename _DataType>
38-
class dpnp_cholesky_c_kernel;
3937

4038
template <typename _DataType>
41-
void dpnp_cholesky_c(void* array1_in, void* result1, size_t* shape)
39+
void dpnp_cholesky_c(void* array1_in, void* result1, const size_t size, const size_t data_size)
4240
{
43-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
44-
_DataType* l_result = reinterpret_cast<_DataType*>(result1);
41+
cl::sycl::event event;
4542

46-
size_t n = shape[0];
43+
_DataType* in_array = reinterpret_cast<_DataType*>(array1_in);
44+
_DataType* result = reinterpret_cast<_DataType*>(result1);
4745

48-
l_result[0] = sqrt(array_1[0]);
46+
size_t iters = size / (data_size * data_size);
4947

50-
for (size_t j = 1; j < n; j++)
48+
for (size_t k = 0; k < iters; ++k)
5149
{
52-
l_result[j * n] = array_1[j * n] / l_result[0];
53-
}
50+
_DataType matrix[data_size * data_size];
51+
_DataType result_[data_size * data_size];
5452

55-
for (size_t i = 1; i < n; i++)
56-
{
57-
_DataType sum_val = 0;
58-
for (size_t p = 0; p < i - 1; p++)
53+
for (size_t t = 0; t < data_size * data_size; ++t)
5954
{
60-
sum_val += l_result[i * n + p - 1] * l_result[i * n + p - 1];
55+
matrix[t] = in_array[k * (data_size * data_size) + t];
56+
6157
}
62-
l_result[i * n + i - 1] = sqrt(array_1[i * n + i - 1] - sum_val);
63-
}
6458

65-
for (size_t i = 1; i < n - 1; i++)
66-
{
67-
for (size_t j = i; j < n; j++)
59+
for (size_t it = 0; it < data_size * data_size; ++it)
6860
{
69-
_DataType sum_val = 0;
70-
for (size_t p = 0; p < i - 1; p++)
61+
result_[it] = matrix[it];
62+
}
63+
64+
const std::int64_t n = data_size;
65+
66+
const std::int64_t lda = std::max<size_t>(1UL, n);
67+
68+
const std::int64_t scratchpad_size = mkl_lapack::potrf_scratchpad_size<_DataType>(
69+
DPNP_QUEUE, oneapi::mkl::uplo::upper, n, lda);
70+
71+
_DataType* scratchpad = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(scratchpad_size * sizeof(_DataType)));
72+
73+
event = mkl_lapack::potrf(DPNP_QUEUE,
74+
oneapi::mkl::uplo::upper,
75+
n,
76+
result_,
77+
lda,
78+
scratchpad,
79+
scratchpad_size);
80+
81+
event.wait();
82+
83+
for (size_t i = 0; i < data_size; i++)
84+
{
85+
bool arg = false;
86+
for (size_t j = 0; j < data_size; j++)
7187
{
72-
sum_val += l_result[i * n + p - 1] * l_result[j * n + p - 1];
88+
if (i == j - 1)
89+
{
90+
arg = true;
91+
}
92+
if (arg)
93+
{
94+
result_[i * data_size + j] = 0;
95+
}
7396
}
74-
l_result[j * n + i - 1] = (1 / l_result[i * n + i - 1]) * (array_1[j * n + i - 1] - sum_val);
7597
}
98+
99+
dpnp_memory_free_c(scratchpad);
100+
101+
for (size_t t = 0; t < data_size * data_size; ++t)
102+
{
103+
result[k * (data_size * data_size) + t] = result_[t];
104+
105+
}
106+
76107
}
77-
return;
108+
78109
}
79110

80111
template <typename _DataType>
@@ -431,8 +462,6 @@ void dpnp_svd_c(void* array1_in, void* result1, void* result2, void* result3, si
431462

432463
void func_map_init_linalg_func(func_map_t& fmap)
433464
{
434-
fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_cholesky_c<int>};
435-
fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_cholesky_c<long>};
436465
fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_cholesky_c<float>};
437466
fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_cholesky_c<double>};
438467

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,22 +57,30 @@ __all__ = [
5757
ctypedef void(*custom_linalg_1in_1out_func_ptr_t)(void *, void * , size_t * , size_t)
5858
ctypedef void(*custom_linalg_1in_1out_func_ptr_t_)(void * , void * , size_t * )
5959
ctypedef void(*custom_linalg_1in_1out_with_size_func_ptr_t_)(void *, void * , size_t)
60+
ctypedef void(*custom_linalg_1in_1out_with_2size_func_ptr_t_)(void *, void * , size_t, size_t)
6061
ctypedef void(*custom_linalg_1in_3out_shape_t)(void *, void * , void * , void * , size_t , size_t )
6162

6263

6364
cpdef dparray dpnp_cholesky(dparray input):
64-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
65+
if input.dtype == dpnp.int32 or input.dtype == dpnp.int64:
66+
input_ = input.astype(dpnp.float64)
67+
else:
68+
input_ = input
69+
70+
size_ = input_.shape[-1]
71+
72+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input_.dtype)
6573

6674
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_CHOLESKY, param1_type, param1_type)
6775

6876
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
69-
cdef dparray result = dparray(input.size, dtype=result_type)
77+
cdef dparray result = dparray(input_.shape, dtype=result_type)
78+
79+
cdef custom_linalg_1in_1out_with_2size_func_ptr_t_ func = <custom_linalg_1in_1out_with_2size_func_ptr_t_ > kernel_data.ptr
7080

71-
cdef custom_linalg_1in_1out_func_ptr_t_ func = <custom_linalg_1in_1out_func_ptr_t_ > kernel_data.ptr
81+
func(input_.get_data(), result.get_data(), input.size, size_)
7282

73-
func(input.get_data(), result.get_data(), < size_t * > input._dparray_shape.data())
74-
l_result = result.reshape(input.shape)
75-
return l_result
83+
return result
7684

7785

7886
cpdef dparray dpnp_cond(dparray input, p):

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,16 @@ def cholesky(input):
8686
Upper or lower-triangular Cholesky factor of `input`. Returns a
8787
matrix object if `input` is a matrix object.
8888
"""
89-
is_input_dparray = isinstance(input, dparray)
90-
91-
if not use_origin_backend(input) and is_input_dparray and input.ndim == 2 and \
92-
input.shape[0] == input.shape[1] and input.shape[0] > 0:
93-
result = dpnp_cholesky(input)
9489

95-
return result
90+
if not use_origin_backend(input):
91+
if not isinstance(input, dparray):
92+
pass
93+
elif input.shape[-1] != input.shape[-2]:
94+
pass
95+
elif input.ndim < 3:
96+
pass
97+
else:
98+
return dpnp_cholesky(input)
9699

97100
return call_origin(numpy.linalg.cholesky, input)
98101

tests/test_linalg.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,15 @@ def vvsort(val, vec, size):
2222
vec[k, imax] = temp
2323

2424

25-
def test_cholesky():
26-
a = numpy.array([[[1, -2], [2, 5]]])
25+
@pytest.mark.parametrize("array",
26+
[[[[1, -2], [2, 5]]],
27+
[[[1., -2.], [2., 5.]]],
28+
[[[1., -2.], [2., 5.]], [[1., -2.], [2., 5.]]]],
29+
ids=['[[[1, -2], [2, 5]]]',
30+
'[[[1., -2.], [2., 5.]]]',
31+
'[[[1., -2.], [2., 5.]], [[1., -2.], [2., 5.]]]'])
32+
def test_cholesky(array):
33+
a = numpy.array(array)
2734
ia = inp.array(a)
2835
result = inp.linalg.cholesky(ia)
2936
expected = numpy.linalg.cholesky(a)

0 commit comments

Comments
 (0)