Skip to content

Commit e694d17

Browse files
committed
Formatting, documentation adjusted
1 parent 717e60c commit e694d17

File tree

2 files changed

+71
-58
lines changed

2 files changed

+71
-58
lines changed

dpctl/tensor/_ctors.py

Lines changed: 69 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,7 @@ def linspace(
10361036
hev.wait()
10371037
return res
10381038

1039+
10391040
def eye(
10401041
n_rows,
10411042
n_cols=None,
@@ -1047,61 +1048,73 @@ def eye(
10471048
device=None,
10481049
usm_type="device",
10491050
sycl_queue=None
1050-
):
1051-
"""
1052-
eye(n_rows, n_cols = None, /, *, k = 0, dtype = None, \
1053-
device = None, usm_type="device", sycl_queue=None) -> usm_ndarray
1051+
):
1052+
"""
1053+
eye(n_rows, n_cols = None, /, *, k = 0, dtype = None, \
1054+
device = None, usm_type="device", sycl_queue=None) -> usm_ndarray
10541055
1055-
Creates `usm_ndarray` where the `k`th diagonal elements are one and others are zero.
1056+
Creates `usm_ndarray` with ones on the `k`th diagonal.
10561057
1057-
Args:
1058-
n_rows: number of rows in the output array.
1059-
n_cols (optional): number of columns in the output array. If None,
1060-
n_cols = n_rows. Default: `None`.
1061-
k: index of the diagonal, with 0 as the main diagonal. A positive value of k
1062-
is an upper diagonal, a negative value is a low diagonal. Default: `0`.
1063-
dtype (optional): data type of the array. Can be typestring,
1064-
a `numpy.dtype` object, `numpy` char string, or a numpy
1065-
scalar type. Default: None
1066-
order ("C" or F"): memory layout for the array. Default: "C"
1067-
device (optional): array API concept of device where the output array
1068-
is created. `device` can be `None`, a oneAPI filter selector string,
1069-
an instance of :class:`dpctl.SyclDevice` corresponding to a
1070-
non-partitioned SYCL device, an instance of
1071-
:class:`dpctl.SyclQueue`, or a `Device` object returnedby
1072-
`dpctl.tensor.usm_array.device`. Default: `None`.
1073-
usm_type ("device"|"shared"|"host", optional): The type of SYCL USM
1074-
allocation for the output array. Default: `"device"`.
1075-
sycl_queue (:class:`dpctl.SyclQueue`, optional): The SYCL queue to use
1076-
for output array allocation and copying. `sycl_queue` and `device`
1077-
are exclusive keywords, i.e. use one or another. If both are
1078-
specified, a `TypeError` is raised unless both imply the same
1079-
underlying SYCL queue to be used. If both are `None`, the
1080-
`dpctl.SyclQueue()` is used for allocation and copying.
1081-
Default: `None`.
1082-
"""
1083-
if n_cols is None:
1084-
n_cols = n_rows
1085-
#allocate a 1D array of zeros, length equal to n_cols * n_rows
1086-
k_dt = type(k)
1087-
if not np.issubdtype(k_dt, np.integer):
1088-
raise TypeError(
1089-
"k keyword must be an integer, got {type}".format(type=k_dt)
1090-
)
1091-
x = zeros((n_rows * n_cols,), dtype=dtype, order=order, device=device, usm_type=usm_type, sycl_queue=sycl_queue)
1092-
if k > -n_rows and k < n_cols:
1093-
#find the length of the diagonal
1094-
l = min(n_cols, n_rows, n_cols-k, n_rows+k)
1095-
#i is the first index of the diagonal in 1D index space, j is the last, s is the step size
1096-
if order == "C":
1097-
s = n_cols+1
1098-
i = k if k >= 0 else n_cols*-k
1099-
else:
1100-
s = n_rows+1
1101-
i = n_rows*k if k > 0 else -k
1102-
#last index + 1 prevents slice from excluding the last element
1103-
j = i+((l-1)*s)+1
1104-
x[i:j:s] = 1
1105-
#copy=False ensures no wasted memory copying the array
1106-
#and as the order parameter is the same, a copy should never be necessary
1107-
return dpt.reshape(x, (n_rows, n_cols), order=order, copy=False)
1058+
Args:
1059+
n_rows: number of rows in the output array.
1060+
n_cols (optional): number of columns in the output array. If None,
1061+
n_cols = n_rows. Default: `None`.
1062+
k: index of the diagonal, with 0 as the main diagonal.
1063+
A positive value of k is a superdiagonal, a negative value
1064+
is a subdiagonal.
1065+
Raises `TypeError` if k is not an integer.
1066+
Default: `0`.
1067+
dtype (optional): data type of the array. Can be typestring,
1068+
a `numpy.dtype` object, `numpy` char string, or a numpy
1069+
scalar type. Default: None
1070+
order ("C" or F"): memory layout for the array. Default: "C"
1071+
device (optional): array API concept of device where the output array
1072+
is created. `device` can be `None`, a oneAPI filter selector string,
1073+
an instance of :class:`dpctl.SyclDevice` corresponding to a
1074+
non-partitioned SYCL device, an instance of
1075+
:class:`dpctl.SyclQueue`, or a `Device` object returnedby
1076+
`dpctl.tensor.usm_array.device`. Default: `None`.
1077+
usm_type ("device"|"shared"|"host", optional): The type of SYCL USM
1078+
allocation for the output array. Default: `"device"`.
1079+
sycl_queue (:class:`dpctl.SyclQueue`, optional): The SYCL queue to use
1080+
for output array allocation and copying. `sycl_queue` and `device`
1081+
are exclusive keywords, i.e. use one or another. If both are
1082+
specified, a `TypeError` is raised unless both imply the same
1083+
underlying SYCL queue to be used. If both are `None`, the
1084+
`dpctl.SyclQueue()` is used for allocation and copying.
1085+
Default: `None`.
1086+
"""
1087+
if n_cols is None:
1088+
n_cols = n_rows
1089+
# allocate a 1D array of zeros, length equal to n_cols * n_rows
1090+
k_dt = type(k)
1091+
if not np.issubdtype(k_dt, np.integer):
1092+
raise TypeError(
1093+
"k keyword must be an integer, got {type}".format(type=k_dt)
1094+
)
1095+
x = zeros(
1096+
(n_rows * n_cols,),
1097+
dtype=dtype,
1098+
order=order,
1099+
device=device,
1100+
usm_type=usm_type,
1101+
sycl_queue=sycl_queue
1102+
)
1103+
if k > -n_rows and k < n_cols:
1104+
# find the length of the diagonal
1105+
L = min(
1106+
n_cols,
1107+
n_rows,
1108+
n_cols-k,
1109+
n_rows+k)
1110+
# i is the first index of diagonal, j is the last, s is the step size
1111+
if order == "C":
1112+
s = n_cols + 1
1113+
i = k if k >= 0 else n_cols*-k
1114+
else:
1115+
s = n_rows + 1
1116+
i = n_rows*k if k > 0 else -k
1117+
j = i + s*(L-1) + 1
1118+
x[i:j:s] = 1
1119+
# copy=False ensures no wasted memory copying the array
1120+
return dpt.reshape(x, (n_rows, n_cols), order=order, copy=False)

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import ctypes
1818
import numbers
19-
from typing import Type
2019

2120
import numpy as np
2221
import pytest
@@ -1292,7 +1291,8 @@ def test_common_arg_validation():
12921291
with pytest.raises(TypeError):
12931292
dpt.eye(4, k=1.2)
12941293

1295-
@pytest.mark.parametrize("shapes", [(0,), (1,), (7,), (6, 1), (3, 9), (10,5)])
1294+
1295+
@pytest.mark.parametrize("shapes", [(0,), (1,), (7,), (6, 1), (3, 9), (10, 5)])
12961296
@pytest.mark.parametrize("k", np.arange(-4, 5, 1))
12971297
@pytest.mark.parametrize("orders", ["C", "F"])
12981298
def test_eye(shapes, k, orders):

0 commit comments

Comments
 (0)