Skip to content

Commit 4284b40

Browse files
authored
Linalg part2 (#874)
* inv kernel fix
1 parent de6e7b3 commit 4284b40

File tree

4 files changed

+32
-34
lines changed

4 files changed

+32
-34
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ INP_DLLEXPORT void dpnp_initval_c(void* result1, void* value, size_t size);
622622
* @param [in] shape Shape of input array.
623623
* @param [in] ndim Number of elements in shape.
624624
*/
625-
template <typename _DataType>
625+
template <typename _DataType, typename _ResultType>
626626
INP_DLLEXPORT void dpnp_inv_c(void* array1_in, void* result1, size_t* shape, size_t ndim);
627627

628628
/**

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,17 @@ void dpnp_det_c(void* array1_in, void* result1, size_t* shape, size_t ndim)
190190
return;
191191
}
192192

193-
template <typename _DataType>
193+
template <typename _DataType, typename _ResultType>
194194
void dpnp_inv_c(void* array1_in, void* result1, size_t* shape, size_t ndim)
195195
{
196196
(void)ndim; // avoid warning unused variable
197197
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
198-
_DataType* result = reinterpret_cast<_DataType*>(result1);
198+
_ResultType* result = reinterpret_cast<_ResultType*>(result1);
199199

200200
size_t n = shape[0];
201201

202-
_DataType a_arr[n][n];
203-
_DataType e_arr[n][n];
202+
_ResultType a_arr[n][n];
203+
_ResultType e_arr[n][n];
204204

205205
for (size_t i = 0; i < n; ++i)
206206
{
@@ -541,10 +541,10 @@ void func_map_init_linalg_func(func_map_t& fmap)
541541
fmap[DPNPFuncName::DPNP_FN_DET][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_det_c<float>};
542542
fmap[DPNPFuncName::DPNP_FN_DET][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_det_c<double>};
543543

544-
fmap[DPNPFuncName::DPNP_FN_INV][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_inv_c<int>};
545-
fmap[DPNPFuncName::DPNP_FN_INV][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_inv_c<long>};
546-
fmap[DPNPFuncName::DPNP_FN_INV][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_inv_c<float>};
547-
fmap[DPNPFuncName::DPNP_FN_INV][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_inv_c<double>};
544+
fmap[DPNPFuncName::DPNP_FN_INV][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_inv_c<int, double>};
545+
fmap[DPNPFuncName::DPNP_FN_INV][eft_LNG][eft_LNG] = {eft_DBL, (void*)dpnp_inv_c<long, double>};
546+
fmap[DPNPFuncName::DPNP_FN_INV][eft_FLT][eft_FLT] = {eft_DBL, (void*)dpnp_inv_c<float, double>};
547+
fmap[DPNPFuncName::DPNP_FN_INV][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_inv_c<double, double>};
548548

549549
fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_kron_c<int, int, int>};
550550
fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_kron_c<int, long, long>};

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ and the rest of the library
3434

3535
import dpnp
3636
cimport dpnp.dpnp_utils as utils
37+
import dpnp.dpnp_utils as utils_py
3738
from dpnp.dpnp_algo cimport *
3839
from dpnp.dparray cimport dparray
3940
import numpy
@@ -80,7 +81,7 @@ cpdef utils.dpnp_descriptor dpnp_cholesky(utils.dpnp_descriptor input_):
8081
return result
8182

8283

83-
cpdef dparray dpnp_cond(dparray input, p):
84+
cpdef object dpnp_cond(object input, object p):
8485
if p in ('f', 'fro'):
8586
input = input.ravel(order='K')
8687
sqnorm = dpnp.dot(input, input)
@@ -130,7 +131,7 @@ cpdef utils.dpnp_descriptor dpnp_det(utils.dpnp_descriptor input):
130131
return result
131132

132133

133-
cpdef tuple dpnp_eig(dparray x1):
134+
cpdef tuple dpnp_eig(utils.dpnp_descriptor x1):
134135
cdef shape_type_c x1_shape = x1.shape
135136

136137
cdef size_t size = 0 if x1_shape.empty() else x1_shape.front()
@@ -140,14 +141,14 @@ cpdef tuple dpnp_eig(dparray x1):
140141

141142
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
142143

143-
cdef dparray res_val = dparray((size,), dtype=result_type)
144-
cdef dparray res_vec = dparray(x1_shape, dtype=result_type)
144+
cdef utils.dpnp_descriptor res_val = utils.create_output_descriptor((size,), kernel_data.return_type, None)
145+
cdef utils.dpnp_descriptor res_vec = utils.create_output_descriptor(x1_shape, kernel_data.return_type, None)
145146

146147
cdef custom_linalg_2in_1out_func_ptr_t func = <custom_linalg_2in_1out_func_ptr_t > kernel_data.ptr
147148
# call FPTR function
148149
func(x1.get_data(), res_val.get_data(), res_vec.get_data(), size)
149150

150-
return (res_val, res_vec)
151+
return (res_val.get_pyobj(), res_vec.get_pyobj())
151152

152153

153154
cpdef utils.dpnp_descriptor dpnp_eigvals(utils.dpnp_descriptor input):
@@ -168,22 +169,21 @@ cpdef utils.dpnp_descriptor dpnp_eigvals(utils.dpnp_descriptor input):
168169
return res_val
169170

170171

171-
cpdef dparray dpnp_inv(dparray input_):
172-
cdef dparray input = input_.astype(dpnp.float64)
172+
cpdef utils.dpnp_descriptor dpnp_inv(utils.dpnp_descriptor input):
173173
cdef shape_type_c input_shape = input.shape
174174

175175
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
176176

177177
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_INV, param1_type, param1_type)
178178

179-
cdef dparray result = dparray(input.size, dtype=dpnp.float64)
179+
# ceate result array with type given by FPTR data
180+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(input_shape, kernel_data.return_type, None)
180181

181182
cdef custom_linalg_1in_1out_func_ptr_t func = <custom_linalg_1in_1out_func_ptr_t > kernel_data.ptr
182183

183184
func(input.get_data(), result.get_data(), < size_t * > input_shape.data(), input.ndim)
184185

185-
dpnp_result = result.reshape(input.shape)
186-
return dpnp_result
186+
return result
187187

188188

189189
cpdef utils.dpnp_descriptor dpnp_matrix_rank(utils.dpnp_descriptor input):

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,12 @@ def cond(input, p=None):
118118
:obj:`dpnp.norm` : Matrix or vector norm.
119119
"""
120120

121-
is_input_dparray = isinstance(input, dparray)
122-
123-
if (not use_origin_backend(input) and is_input_dparray):
121+
if (not use_origin_backend(input)):
124122
if p in [None, 1, -1, 2, -2, numpy.inf, -numpy.inf, 'fro']:
125-
result = dpnp_cond(input, p=p)
126-
return result.dtype.type(result[0])
123+
result_obj = dpnp_cond(input, p)
124+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
125+
126+
return result
127127
else:
128128
pass
129129

@@ -164,11 +164,10 @@ def eig(x1):
164164
165165
"""
166166

167-
is_x1_dparray = isinstance(x1, dparray)
168-
169-
if (not use_origin_backend(x1) and is_x1_dparray):
170-
if (x1.size > 0):
171-
return dpnp_eig(x1)
167+
x1_desc = dpnp.get_dpnp_descriptor(x1)
168+
if x1_desc:
169+
if (x1_desc.size > 0):
170+
return dpnp_eig(x1_desc)
172171

173172
return call_origin(numpy.linalg.eig, x1)
174173

@@ -214,11 +213,10 @@ def inv(input):
214213
Otherwise the function will be executed sequentially on CPU.
215214
"""
216215

217-
is_input_dparray = isinstance(input, dparray)
218-
219-
if (not use_origin_backend(input) and is_input_dparray):
220-
if input.ndim == 2 and input.shape[0] == input.shape[1] and input.shape[0] >= 2:
221-
return dpnp_inv(input)
216+
x1_desc = dpnp.get_dpnp_descriptor(input)
217+
if x1_desc:
218+
if x1_desc.ndim == 2 and x1_desc.shape[0] == x1_desc.shape[1] and x1_desc.shape[0] >= 2:
219+
return dpnp_inv(x1_desc).get_pyobj()
222220

223221
return call_origin(numpy.linalg.inv, input)
224222

0 commit comments

Comments
 (0)