Skip to content

Commit f7058c0

Browse files
dpnp array container, init impl (#993)
* dpnp array container, init impl * dpnp_ndarray funcs: __add__, __len__, __mod__, __radd__, __rmatmul__, __rmod__, __rmul__, __rtruediv__, __str__, __truediv__, itemsize, std, strides * fix astype func * fix argsort result shape for input with ndim==0 * fix partition function: in case ndim==0 pass to fallback
1 parent 61ae592 commit f7058c0

File tree

8 files changed

+559
-10
lines changed

8 files changed

+559
-10
lines changed

dpnp/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
if "numpy" in sys.modules:
4141
warnings.warn("\nDPNP: Module NumPy found. Please load DPNP module before NumPy.\n")
4242

43-
from dpnp.dparray import dparray as ndarray
43+
# from dpnp.dparray import dparray as ndarray
44+
from dpnp.dpnp_array import dpnp_array as ndarray
4445
from dpnp.dpnp_iface import *
4546
from dpnp.dpnp_iface import __all__ as _iface__all__
4647
from dpnp.dpnp_iface_types import *

dpnp/dpnp_algo/dpnp_algo_manipulation.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ cpdef utils.dpnp_descriptor dpnp_repeat(utils.dpnp_descriptor array1, repeats, a
148148

149149
cpdef utils.dpnp_descriptor dpnp_reshape(utils.dpnp_descriptor array1, newshape, order=None):
150150
# return dpnp.get_dpnp_descriptor(dpctl.tensor.usm_ndarray(newshape, dtype=numpy.dtype(array1.dtype).name, buffer=array1.get_pyobj()))
151-
return dpnp.get_dpnp_descriptor(dpctl.tensor.reshape(array1.get_pyobj(), newshape))
151+
# return dpnp.get_dpnp_descriptor(dpctl.tensor.reshape(array1.get_pyobj(), newshape))
152+
return dpnp.get_dpnp_descriptor(dpctl.tensor.reshape(array1.get_pyobj()._array_obj, newshape))
152153

153154

154155
cpdef utils.dpnp_descriptor dpnp_transpose(utils.dpnp_descriptor array1, axes=None):

dpnp/dpnp_algo/dpnp_algo_sorting.pyx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ ctypedef void(*fptr_dpnp_searchsorted_t)(void * , const void * , const void * ,
4747

4848

4949
cpdef utils.dpnp_descriptor dpnp_argsort(utils.dpnp_descriptor x1):
50-
return call_fptr_1in_1out(DPNP_FN_ARGSORT, x1, x1.shape)
50+
cdef shape_type_c result_shape = x1.shape
51+
if result_shape == ():
52+
result_shape = (1,)
53+
return call_fptr_1in_1out(DPNP_FN_ARGSORT, x1, result_shape)
5154

5255

5356
cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, axis=-1, kind='introselect', order=None):

0 commit comments

Comments
 (0)