Skip to content

Commit b41271f

Browse files
authored
linalg module is free of dparray (#875)
1 parent 4284b40 commit b41271f

File tree

2 files changed

+47
-63
lines changed

2 files changed

+47
-63
lines changed

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import dpnp
3636
cimport dpnp.dpnp_utils as utils
3737
import dpnp.dpnp_utils as utils_py
3838
from dpnp.dpnp_algo cimport *
39-
from dpnp.dparray cimport dparray
39+
4040
import numpy
4141
cimport numpy
4242

@@ -202,7 +202,7 @@ cpdef utils.dpnp_descriptor dpnp_matrix_rank(utils.dpnp_descriptor input):
202202
return result
203203

204204

205-
cpdef dparray dpnp_norm(dparray input, ord=None, axis=None):
205+
cpdef object dpnp_norm(object input, ord=None, axis=None):
206206
cdef long size_input = input.size
207207
cdef shape_type_c shape_input = input.shape
208208

@@ -246,14 +246,14 @@ cpdef dparray dpnp_norm(dparray input, ord=None, axis=None):
246246
else:
247247
absx = dpnp.abs(input)
248248
absx_size = absx.size
249-
absx_power = dparray(absx_size, dtype=absx.dtype)
249+
absx_power = utils_py.create_output_descriptor_py((absx_size,), absx.dtype, None).get_pyobj()
250250
for i in range(absx_size):
251251
absx_elem = absx.item(i)
252252
absx_power[i] = absx_elem ** ord
253253
absx_ = absx_power.reshape(absx.shape)
254254
ret = dpnp.sum(absx_, axis=axis)
255255
ret_size = ret.size
256-
ret_power = dparray(ret_size)
256+
ret_power = utils_py.create_output_descriptor_py((ret_size,), None, None).get_pyobj()
257257
for i in range(ret_size):
258258
ret_elem = ret.item(i)
259259
ret_power[i] = ret_elem ** (1 / ord)
@@ -270,86 +270,73 @@ cpdef dparray dpnp_norm(dparray input, ord=None, axis=None):
270270
elif ord == 1:
271271
if col_axis > row_axis:
272272
col_axis -= 1
273-
dpnp_sum_val_ = dpnp.sum(dpnp.abs(input), axis=row_axis)
274-
dpnp_sum_val = dpnp_sum_val_ if isinstance(dpnp_sum_val_, dparray) else dpnp.array([dpnp_sum_val_])
275-
dpnp_max_val = dpnp_sum_val.min(axis=col_axis)
276-
ret = dpnp_max_val if isinstance(dpnp_max_val, dparray) else dpnp.array([dpnp_max_val])
273+
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=row_axis)
274+
ret = dpnp_sum_val.min(axis=col_axis)
277275
elif ord == numpy.inf:
278276
if row_axis > col_axis:
279277
row_axis -= 1
280-
dpnp_sum_val_ = dpnp.sum(dpnp.abs(input), axis=col_axis)
281-
dpnp_sum_val = dpnp_sum_val_ if isinstance(dpnp_sum_val_, dparray) else dpnp.array([dpnp_sum_val_])
282-
dpnp_max_val = dpnp_sum_val.max(axis=row_axis)
283-
ret = dpnp_max_val if isinstance(dpnp_max_val, dparray) else dpnp.array([dpnp_max_val])
278+
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=col_axis)
279+
ret = dpnp_sum_val.max(axis=row_axis)
284280
elif ord == -1:
285281
if col_axis > row_axis:
286282
col_axis -= 1
287-
dpnp_sum_val_ = dpnp.sum(dpnp.abs(input), axis=row_axis)
288-
dpnp_sum_val = dpnp_sum_val_ if isinstance(dpnp_sum_val_, dparray) else dpnp.array([dpnp_sum_val_])
289-
dpnp_min_val = dpnp_sum_val.min(axis=col_axis)
290-
ret = dpnp_min_val if isinstance(dpnp_min_val, dparray) else dpnp.array([dpnp_min_val])
283+
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=row_axis)
284+
ret = dpnp_sum_val.min(axis=col_axis)
291285
elif ord == -numpy.inf:
292286
if row_axis > col_axis:
293287
row_axis -= 1
294-
dpnp_sum_val_ = dpnp.sum(dpnp.abs(input), axis=col_axis)
295-
dpnp_sum_val = dpnp_sum_val_ if isinstance(dpnp_sum_val_, dparray) else dpnp.array([dpnp_sum_val_])
296-
dpnp_min_val = dpnp_sum_val.min(axis=row_axis)
297-
ret = dpnp_min_val if isinstance(dpnp_min_val, dparray) else dpnp.array([dpnp_min_val])
288+
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=col_axis)
289+
ret = dpnp_sum_val.min(axis=row_axis)
298290
elif ord in [None, 'fro', 'f']:
299291
ret = dpnp.sqrt(dpnp.sum(input * input, axis=axis))
300292
# elif ord == 'nuc':
301293
# ret = _multi_svd_norm(input, row_axis, col_axis, sum)
302294
else:
303295
raise ValueError("Invalid norm order for matrices.")
296+
304297
return ret
305298
else:
306299
raise ValueError("Improper number of dimensions to norm.")
307300

308301

309-
cpdef tuple dpnp_qr(dparray x1, mode):
302+
cpdef tuple dpnp_qr(utils.dpnp_descriptor x1, str mode):
310303
cdef size_t size_m = x1.shape[0]
311304
cdef size_t size_n = x1.shape[1]
305+
cdef size_t size_tau = min(size_m, size_n)
312306

313307
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
314308
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_QR, param1_type, param1_type)
315309

316-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
317-
318-
cdef dparray res_q = dparray((size_m, size_m), dtype=result_type)
319-
cdef dparray res_r = dparray((size_m, size_n), dtype=result_type)
320-
321-
size_tau = min(size_m, size_n)
322-
cdef dparray tau = dparray((size_tau, ), dtype=result_type)
310+
cdef utils.dpnp_descriptor res_q = utils.create_output_descriptor((size_m, size_m), kernel_data.return_type, None)
311+
cdef utils.dpnp_descriptor res_r = utils.create_output_descriptor((size_m, size_n), kernel_data.return_type, None)
312+
cdef utils.dpnp_descriptor tau = utils.create_output_descriptor((size_tau, ), kernel_data.return_type, None)
323313

324314
cdef custom_linalg_1in_3out_shape_t func = < custom_linalg_1in_3out_shape_t > kernel_data.ptr
325315

326316
func(x1.get_data(), res_q.get_data(), res_r.get_data(), tau.get_data(), size_m, size_n)
327317

328-
return (res_q, res_r)
318+
return (res_q.get_pyobj(), res_r.get_pyobj())
329319

330320

331-
cpdef tuple dpnp_svd(dparray x1, full_matrices, compute_uv, hermitian):
321+
cpdef tuple dpnp_svd(utils.dpnp_descriptor x1, cpp_bool full_matrices, cpp_bool compute_uv, cpp_bool hermitian):
332322
cdef size_t size_m = x1.shape[0]
333323
cdef size_t size_n = x1.shape[1]
324+
cdef size_t size_s = min(size_m, size_n)
334325

335326
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
336327
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_SVD, param1_type, param1_type)
337328

338-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
339-
329+
cdef DPNPFuncType type_s = DPNP_FT_DOUBLE
340330
if x1.dtype == dpnp.float32:
341-
type_s = dpnp.float32
342-
else:
343-
type_s = dpnp.float64
331+
type_s = DPNP_FT_FLOAT
344332

345-
size_s = min(size_m, size_n)
346333

347-
cdef dparray res_u = dparray((size_m, size_m), dtype=result_type)
348-
cdef dparray res_s = dparray((size_s, ), dtype=type_s)
349-
cdef dparray res_vt = dparray((size_n, size_n), dtype=result_type)
334+
cdef utils.dpnp_descriptor res_u = utils.create_output_descriptor((size_m, size_m), kernel_data.return_type, None)
335+
cdef utils.dpnp_descriptor res_s = utils.create_output_descriptor((size_s, ), type_s, None)
336+
cdef utils.dpnp_descriptor res_vt = utils.create_output_descriptor((size_n, size_n), kernel_data.return_type, None)
350337

351338
cdef custom_linalg_1in_3out_shape_t func = < custom_linalg_1in_3out_shape_t > kernel_data.ptr
352339

353340
func(x1.get_data(), res_u.get_data(), res_s.get_data(), res_vt.get_data(), size_m, size_n)
354341

355-
return (res_u, res_s, res_vt)
342+
return (res_u.get_pyobj(), res_s.get_pyobj(), res_vt.get_pyobj())

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
import dpnp
4444
import numpy
4545

46-
from dpnp.dparray import dparray
4746
from dpnp.dpnp_utils import *
4847
from dpnp.linalg.dpnp_algo_linalg import *
4948

@@ -231,7 +230,7 @@ def matrix_power(input, count):
231230
232231
Returns
233232
-------
234-
output : dparray
233+
output : array
235234
Returns the dot product of the supplied arrays.
236235
237236
See Also
@@ -328,7 +327,7 @@ def multi_dot(arrays, out=None):
328327
return result
329328

330329

331-
def norm(input, ord=None, axis=None, keepdims=False):
330+
def norm(x1, ord=None, axis=None, keepdims=False):
332331
"""
333332
Matrix or vector norm.
334333
This function is able to return one of eight different matrix norms,
@@ -362,22 +361,21 @@ def norm(input, ord=None, axis=None, keepdims=False):
362361
Norm of the matrix or vector(s).
363362
"""
364363

365-
if not use_origin_backend(input):
366-
if not isinstance(input, dparray):
367-
pass
368-
elif not isinstance(axis, int) and not isinstance(axis, tuple) and axis is not None:
364+
x1_desc = dpnp.get_dpnp_descriptor(x1)
365+
if x1_desc:
366+
if not isinstance(axis, int) and not isinstance(axis, tuple) and axis is not None:
369367
pass
370368
elif keepdims is not False:
371369
pass
372370
elif ord not in [None, 0, 3, 'fro', 'f']:
373371
pass
374372
else:
375-
result_obj = dpnp_norm(input, ord=ord, axis=axis)
373+
result_obj = dpnp_norm(x1, ord=ord, axis=axis)
376374
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
377375

378376
return result
379377

380-
return call_origin(numpy.linalg.norm, input, ord, axis, keepdims)
378+
return call_origin(numpy.linalg.norm, x1, ord, axis, keepdims)
381379

382380

383381
def qr(x1, mode='reduced'):
@@ -396,21 +394,19 @@ def qr(x1, mode='reduced'):
396394
397395
"""
398396

399-
if not use_origin_backend(x1):
400-
if not isinstance(x1, dparray):
401-
pass
402-
elif mode != 'reduced':
397+
x1_desc = dpnp.get_dpnp_descriptor(x1)
398+
if x1_desc:
399+
if mode != 'reduced':
403400
pass
404401
else:
405-
# I see something wrong with it. it is couse SIGSEGV in 1 of 10 test times
406-
res_q, res_r = dpnp_qr(x1, mode)
402+
result_tup = dpnp_qr(x1, mode)
407403

408-
return (res_q, res_r)
404+
return result_tup
409405

410406
return call_origin(numpy.linalg.qr, x1, mode)
411407

412408

413-
def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
409+
def svd(x1, full_matrices=True, compute_uv=True, hermitian=False):
414410
"""
415411
Singular Value Decomposition.
416412
@@ -467,10 +463,9 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
467463
468464
"""
469465

470-
if not use_origin_backend(a):
471-
if not isinstance(a, dparray):
472-
pass
473-
elif not a.ndim == 2:
466+
x1_desc = dpnp.get_dpnp_descriptor(x1)
467+
if x1_desc:
468+
if not x1_desc.ndim == 2:
474469
pass
475470
elif not full_matrices == True:
476471
pass
@@ -479,6 +474,8 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
479474
elif not hermitian == False:
480475
pass
481476
else:
482-
return dpnp_svd(a, full_matrices, compute_uv, hermitian)
477+
result_tup = dpnp_svd(x1_desc, full_matrices, compute_uv, hermitian)
478+
479+
return result_tup
483480

484-
return call_origin(numpy.linalg.svd, a, full_matrices, compute_uv, hermitian)
481+
return call_origin(numpy.linalg.svd, x1, full_matrices, compute_uv, hermitian)

0 commit comments

Comments
 (0)