Skip to content

Commit 6f02cb2

Browse files
committed
Add implementation of dpnp.ndarray.view method
1 parent d15d395 commit 6f02cb2

File tree

1 file changed

+64
-1
lines changed

1 file changed

+64
-1
lines changed

dpnp/dpnp_array.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
# *****************************************************************************
2626

2727
import dpctl.tensor as dpt
28+
import dpctl.tensor._type_utils as dtu
2829
from dpctl.tensor._numpy_helper import AxisError
2930

3031
import dpnp
@@ -1979,5 +1980,67 @@ def var(
19791980
correction=correction,
19801981
)
19811982

1983+
def view(self, dtype=None):
1984+
"""TBD"""
19821985

1983-
# 'view'
1986+
old_sh = self.shape
1987+
old_strides = self.strides
1988+
1989+
if dtype is None:
1990+
return dpnp_array(old_sh, buffer=self, strides=old_strides)
1991+
1992+
new_dt = dpnp.dtype(dtype)
1993+
new_dt = dtu._to_device_supported_dtype(new_dt, self.sycl_device)
1994+
1995+
new_itemsz = new_dt.itemsize
1996+
old_itemsz = self.dtype.itemsize
1997+
if new_itemsz == old_itemsz:
1998+
return dpnp_array(
1999+
old_sh, dtype=new_dt, buffer=self, strides=old_strides
2000+
)
2001+
2002+
ndim = self.ndim
2003+
if ndim == 0:
2004+
raise ValueError(
2005+
"Changing the dtype of a 0d array is only supported "
2006+
"if the itemsize is unchanged"
2007+
)
2008+
2009+
# resize on last axis only
2010+
axis = ndim - 1
2011+
if old_sh[axis] != 1 and self.size != 0 and old_strides[axis] != 1:
2012+
raise ValueError(
2013+
"To change to a dtype of a different size, "
2014+
"the last axis must be contiguous"
2015+
)
2016+
2017+
# normalize strides whenever itemsize changes
2018+
if old_itemsz > new_itemsz:
2019+
new_strides = list(
2020+
el * (old_itemsz // new_itemsz) for el in old_strides
2021+
)
2022+
else:
2023+
new_strides = list(
2024+
el // (new_itemsz // old_itemsz) for el in old_strides
2025+
)
2026+
new_strides[axis] = 1
2027+
new_strides = tuple(new_strides)
2028+
2029+
new_dim = old_sh[axis] * old_itemsz
2030+
if new_dim % new_itemsz != 0:
2031+
raise ValueError(
2032+
"When changing to a larger dtype, its size must be a divisor "
2033+
"of the total size in bytes of the last axis of the array"
2034+
)
2035+
2036+
# normalize shape whenever itemsize changes
2037+
new_sh = list(old_sh)
2038+
new_sh[axis] = new_dim // new_itemsz
2039+
new_sh = tuple(new_sh)
2040+
2041+
return dpnp_array(
2042+
new_sh,
2043+
dtype=new_dt,
2044+
buffer=self,
2045+
strides=new_strides,
2046+
)

0 commit comments

Comments
 (0)