Skip to content

Commit a7ee281

Browse files
Implement dpnp.astype (#961)
* Implement dpnp.astype Co-authored-by: Alexander-Makaryev <[email protected]>
1 parent 1a87f91 commit a7ee281

File tree

5 files changed

+41
-11
lines changed

5 files changed

+41
-11
lines changed

dpnp/dpnp_algo/dpnp_algo_arraycreation.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,9 @@ cpdef tuple dpnp_linspace(start, stop, num, endpoint, retstep, dtype, axis):
212212
return (result.get_pyobj(), step)
213213

214214

215-
cpdef object dpnp_logspace(start, stop, num, endpoint, base, dtype, axis):
215+
cpdef utils.dpnp_descriptor dpnp_logspace(start, stop, num, endpoint, base, dtype, axis):
216216
temp = dpnp.linspace(start, stop, num=num, endpoint=endpoint)
217-
return dpnp_astype(dpnp.get_dpnp_descriptor(dpnp.power(base, temp)), dtype)
217+
return dpnp.get_dpnp_descriptor(dpnp.astype(dpnp.power(base, temp), dtype))
218218

219219

220220
cpdef list dpnp_meshgrid(xi, copy, sparse, indexing):

dpnp/dpnp_iface.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
__all__ = [
5656
"array_equal",
5757
"asnumpy",
58+
"astype",
5859
"convert_single_elem_array_to_scalar",
5960
"dpnp_queue_initialize",
6061
"dpnp_queue_is_cpu",
@@ -139,6 +140,34 @@ def asnumpy(input, order='C'):
139140
return numpy.asarray(input, order=order)
140141

141142

143+
def astype(x1, dtype, order='K', casting='unsafe', subok=True, copy=True):
144+
"""Copy the array with data type casting."""
145+
if config.__DPNP_OUTPUT_DPCTL__ and hasattr(x1, "__sycl_usm_array_interface__"):
146+
import dpctl.tensor as dpt
147+
# TODO: remove check dpctl.tensor has attribute "astype"
148+
if hasattr(dpt, "astype"):
149+
return dpt.astype(x1, dtype, order=order, casting=casting, copy=copy)
150+
151+
x1_desc = get_dpnp_descriptor(x1)
152+
if not x1_desc:
153+
pass
154+
elif order != 'K':
155+
pass
156+
elif casting != 'unsafe':
157+
pass
158+
elif not subok:
159+
pass
160+
elif not copy:
161+
pass
162+
elif x1_desc.dtype == numpy.complex128 or dtype == numpy.complex128:
163+
pass
164+
elif x1_desc.dtype == numpy.complex64 or dtype == numpy.complex64:
165+
pass
166+
else:
167+
return dpnp_astype(x1_desc, dtype).get_pyobj()
168+
169+
return call_origin(numpy.ndarray.astype, x1, dtype, order=order, casting=casting, subok=subok, copy=copy)
170+
142171
def convert_single_elem_array_to_scalar(obj, keepdims=False):
143172
"""
144173
Convert array with single element to scalar

dpnp/dpnp_iface_statistics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
292292
pass
293293
else:
294294
if x1_desc.dtype != dpnp.float64:
295-
x1_desc = dpnp_astype(x1_desc, dpnp.float64)
295+
x1_desc = dpnp.get_dpnp_descriptor(dpnp.astype(x1, dpnp.float64))
296296

297297
return dpnp_cov(x1_desc).get_pyobj()
298298

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def cholesky(input):
9494
pass
9595
else:
9696
if input.dtype == dpnp.int32 or input.dtype == dpnp.int64:
97-
input_ = dpnp_astype(x1_desc, dpnp.float64)
97+
# TODO memory copy. needs to move into DPNPC
98+
input_ = dpnp.get_dpnp_descriptor(dpnp.astype(input, dpnp.float64))
9899
else:
99100
input_ = x1_desc
100101
return dpnp_cholesky(input_).get_pyobj()

tests/test_random.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ def test_shuffle(self, dtype):
946946
input_x = dpnp.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], dtype=dtype)
947947
dpnp.random.seed(seed)
948948
dpnp.random.shuffle(input_x_int64) # inplace
949-
desired_x = input_x_int64.astype(dtype)
949+
desired_x = dpnp.astype(input_x_int64, dtype)
950950
dpnp.random.seed(seed)
951951
dpnp.random.shuffle(input_x) # inplace
952952
actual_x = input_x
@@ -965,21 +965,21 @@ def test_no_miss_numbers(self, dtype):
965965
assert_array_equal(actual_x, desired_x)
966966

967967
@pytest.mark.parametrize("conv", [lambda x: dpnp.array([]),
968-
lambda x: dpnp.asarray(x).astype(dpnp.int8),
969-
lambda x: dpnp.asarray(x).astype(dpnp.float32),
968+
lambda x: dpnp.astype(asarray(x), dpnp.int8),
969+
lambda x: dpnp.astype(asarray(x), dpnp.float32),
970970
# lambda x: dpnp.asarray(x).astype(dpnp.complex64),
971-
lambda x: dpnp.asarray(x).astype(object),
971+
lambda x: dpnp.astype(asarray(x), object),
972972
lambda x: dpnp.asarray([[i, i] for i in x]),
973973
lambda x: dpnp.vstack([x, x]).T,
974974
lambda x: (dpnp.asarray([(i, i) for i in x], [
975975
("a", int), ("b", int)]).view(dpnp.recarray)),
976976
lambda x: dpnp.asarray([(i, i) for i in x],
977977
[("a", object), ("b", dpnp.int32)])],
978978
ids=['lambda x: dpnp.array([])',
979-
'lambda x: dpnp.asarray(x).astype(dpnp.int8)',
980-
'lambda x: dpnp.asarray(x).astype(dpnp.float32)',
979+
'lambda x: dpnp.astype(asarray(x), dpnp.int8)',
980+
'lambda x: dpnp.astype(asarray(x), dpnp.float32)',
981981
# 'lambda x: dpnp.asarray(x).astype(dpnp.complex64)',
982-
'lambda x: dpnp.asarray(x).astype(object)',
982+
'lambda x: dpnp.astype(asarray(x), object)',
983983
'lambda x: dpnp.asarray([[i, i] for i in x])',
984984
'lambda x: dpnp.vstack([x, x]).T',
985985
'lambda x: (dpnp.asarray([(i, i) for i in x], ['\

0 commit comments

Comments
 (0)