|
25 | 25 | # ***************************************************************************** |
26 | 26 |
|
27 | 27 | import dpctl.tensor as dpt |
| 28 | +import dpctl.tensor._type_utils as dtu |
28 | 29 | from dpctl.tensor._numpy_helper import AxisError |
29 | 30 |
|
30 | 31 | import dpnp |
@@ -1979,5 +1980,67 @@ def var( |
1979 | 1980 | correction=correction, |
1980 | 1981 | ) |
1981 | 1982 |
|
| 1983 | + def view(self, dtype=None): |
| 1984 | + """TBD""" |
1982 | 1985 |
|
1983 | | -# 'view' |
| 1986 | + old_sh = self.shape |
| 1987 | + old_strides = self.strides |
| 1988 | + |
| 1989 | + if dtype is None: |
| 1990 | + return dpnp_array(old_sh, buffer=self, strides=old_strides) |
| 1991 | + |
| 1992 | + new_dt = dpnp.dtype(dtype) |
| 1993 | + new_dt = dtu._to_device_supported_dtype(new_dt, self.sycl_device) |
| 1994 | + |
| 1995 | + new_itemsz = new_dt.itemsize |
| 1996 | + old_itemsz = self.dtype.itemsize |
| 1997 | + if new_itemsz == old_itemsz: |
| 1998 | + return dpnp_array( |
| 1999 | + old_sh, dtype=new_dt, buffer=self, strides=old_strides |
| 2000 | + ) |
| 2001 | + |
| 2002 | + ndim = self.ndim |
| 2003 | + if ndim == 0: |
| 2004 | + raise ValueError( |
| 2005 | + "Changing the dtype of a 0d array is only supported " |
| 2006 | + "if the itemsize is unchanged" |
| 2007 | + ) |
| 2008 | + |
| 2009 | + # resize on last axis only |
| 2010 | + axis = ndim - 1 |
| 2011 | + if old_sh[axis] != 1 and self.size != 0 and old_strides[axis] != 1: |
| 2012 | + raise ValueError( |
| 2013 | + "To change to a dtype of a different size, " |
| 2014 | + "the last axis must be contiguous" |
| 2015 | + ) |
| 2016 | + |
| 2017 | + # normalize strides whenever itemsize changes |
| 2018 | + if old_itemsz > new_itemsz: |
| 2019 | + new_strides = list( |
| 2020 | + el * (old_itemsz // new_itemsz) for el in old_strides |
| 2021 | + ) |
| 2022 | + else: |
| 2023 | + new_strides = list( |
| 2024 | + el // (new_itemsz // old_itemsz) for el in old_strides |
| 2025 | + ) |
| 2026 | + new_strides[axis] = 1 |
| 2027 | + new_strides = tuple(new_strides) |
| 2028 | + |
| 2029 | + new_dim = old_sh[axis] * old_itemsz |
| 2030 | + if new_dim % new_itemsz != 0: |
| 2031 | + raise ValueError( |
| 2032 | + "When changing to a larger dtype, its size must be a divisor " |
| 2033 | + "of the total size in bytes of the last axis of the array" |
| 2034 | + ) |
| 2035 | + |
| 2036 | + # normalize shape whenever itemsize changes |
| 2037 | + new_sh = list(old_sh) |
| 2038 | + new_sh[axis] = new_dim // new_itemsz |
| 2039 | + new_sh = tuple(new_sh) |
| 2040 | + |
| 2041 | + return dpnp_array( |
| 2042 | + new_sh, |
| 2043 | + dtype=new_dt, |
| 2044 | + buffer=self, |
| 2045 | + strides=new_strides, |
| 2046 | + ) |
0 commit comments