Skip to content

Commit 3a7d4f3

Browse files
committed
Rewrote eye constructor to use a dedicated kernel
1 parent d3e41ec commit 3a7d4f3

File tree

2 files changed

+173
-23
lines changed

2 files changed

+173
-23
lines changed

dpctl/tensor/_ctors.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,37 +1084,34 @@ def eye(
10841084
`dpctl.SyclQueue()` is used for allocation and copying.
10851085
Default: `None`.
10861086
"""
1087-
if n_cols is None:
1088-
n_cols = n_rows
1089-
# allocate a 1D array of zeros, length equal to n_cols * n_rows
10901087
if not isinstance(order, str) or len(order) == 0 or order[0] not in "CcFf":
10911088
raise ValueError(
10921089
"Unrecognized order keyword value, expecting 'F' or 'C'."
10931090
)
10941091
else:
10951092
order = order[0].upper()
10961093
n_rows = operator.index(n_rows)
1097-
n_cols = operator.index(n_cols)
1094+
n_cols = n_rows if n_cols is None else operator.index(n_cols)
10981095
k = operator.index(k)
1099-
x = zeros(
1100-
(n_rows * n_cols,),
1096+
if k >= n_cols or -k >= n_rows:
1097+
return dpt.zeros(
1098+
(n_rows, n_cols),
1099+
dtype=dtype,
1100+
order=order,
1101+
device=device,
1102+
usm_type=usm_type,
1103+
sycl_queue=sycl_queue,
1104+
)
1105+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
1106+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
1107+
res = dpt.usm_ndarray(
1108+
(n_rows, n_cols),
11011109
dtype=dtype,
1110+
buffer=usm_type,
11021111
order=order,
1103-
device=device,
1104-
usm_type=usm_type,
1105-
sycl_queue=sycl_queue,
1112+
buffer_ctor_kwargs={"queue": sycl_queue},
11061113
)
1107-
if k > -n_rows and k < n_cols:
1108-
# find the length of the diagonal
1109-
L = min(n_cols, n_rows, n_cols - k, n_rows + k)
1110-
# i is the first index of diagonal, j is the last, s is the step size
1111-
if order == "C":
1112-
s = n_cols + 1
1113-
i = k if k >= 0 else n_cols * -k
1114-
else:
1115-
s = n_rows + 1
1116-
i = n_rows * k if k > 0 else -k
1117-
j = i + s * (L - 1) + 1
1118-
x[i:j:s] = 1
1119-
# copy=False ensures no wasted memory copying the array
1120-
return dpt.reshape(x, (n_rows, n_cols), order=order, copy=False)
1114+
if n_rows != 0 and n_cols != 0:
1115+
hev, _ = ti._eye(k, dst=res, sycl_queue=sycl_queue)
1116+
hev.wait()
1117+
return res

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ template <typename srcT, typename dstT, int nd> class copy_cast_spec_kernel;
4545
template <typename Ty> class copy_for_reshape_generic_kernel;
4646
template <typename Ty> class linear_sequence_step_kernel;
4747
template <typename Ty, typename wTy> class linear_sequence_affine_kernel;
48+
template <typename Ty> class eye_kernel;
4849

4950
static dpctl::tensor::detail::usm_ndarray_types array_types;
5051

@@ -1742,6 +1743,144 @@ usm_ndarray_full(py::object py_value,
17421743
}
17431744
}
17441745

1746+
/* ================ Eye ================== */
1747+
1748+
typedef sycl::event (*eye_fn_ptr_t)(sycl::queue,
1749+
size_t nelems, // num_elements
1750+
py::ssize_t start,
1751+
py::ssize_t end,
1752+
py::ssize_t step,
1753+
char *, // dst_data_ptr
1754+
const std::vector<sycl::event> &);
1755+
1756+
static eye_fn_ptr_t eye_dispatch_vector[_ns::num_types];
1757+
1758+
template <typename Ty> class EyeFunctor
1759+
{
1760+
private:
1761+
Ty *p = nullptr;
1762+
py::ssize_t start_v;
1763+
py::ssize_t end_v;
1764+
py::ssize_t step_v;
1765+
1766+
public:
1767+
EyeFunctor(char *dst_p,
1768+
const py::ssize_t v0,
1769+
const py::ssize_t v1,
1770+
const py::ssize_t dv)
1771+
: p(reinterpret_cast<Ty *>(dst_p)), start_v(v0), end_v(v1), step_v(dv)
1772+
{
1773+
}
1774+
1775+
void operator()(sycl::id<1> wiid) const
1776+
{
1777+
Ty set_v = 0;
1778+
py::ssize_t i = static_cast<py::ssize_t>(wiid.get(0));
1779+
if (i >= start_v and i <= end_v) {
1780+
if ((i - start_v) % step_v == 0) {
1781+
set_v = 1;
1782+
}
1783+
}
1784+
p[i] = set_v;
1785+
}
1786+
};
1787+
1788+
template <typename Ty>
1789+
sycl::event eye_impl(sycl::queue exec_q,
1790+
size_t nelems,
1791+
const py::ssize_t start,
1792+
const py::ssize_t end,
1793+
const py::ssize_t step,
1794+
char *array_data,
1795+
const std::vector<sycl::event> &depends)
1796+
{
1797+
sycl::event eye_event = exec_q.submit([&](sycl::handler &cgh) {
1798+
cgh.depends_on(depends);
1799+
cgh.parallel_for<eye_kernel<Ty>>(
1800+
sycl::range<1>{nelems},
1801+
EyeFunctor<Ty>(array_data, start, end, step));
1802+
});
1803+
1804+
return eye_event;
1805+
}
1806+
1807+
template <typename fnT, typename Ty> struct EyeFactory
1808+
{
1809+
fnT get()
1810+
{
1811+
fnT f = eye_impl<Ty>;
1812+
return f;
1813+
}
1814+
};
1815+
1816+
std::pair<sycl::event, sycl::event>
1817+
eye(py::ssize_t k,
1818+
dpctl::tensor::usm_ndarray dst,
1819+
sycl::queue exec_q,
1820+
const std::vector<sycl::event> &depends = {})
1821+
{
1822+
// dst must be 2D
1823+
1824+
if (dst.get_ndim() != 2) {
1825+
throw py::value_error(
1826+
"usm_ndarray_eye: Expecting 2D array to populate");
1827+
}
1828+
1829+
sycl::queue dst_q = dst.get_queue();
1830+
if (dst_q != exec_q && dst_q.get_context() != exec_q.get_context()) {
1831+
throw py::value_error(
1832+
"Execution queue context is not the same as allocation context");
1833+
}
1834+
1835+
int dst_typenum = dst.get_typenum();
1836+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
1837+
1838+
const py::ssize_t nelem = dst.get_size();
1839+
const py::ssize_t rows = dst.get_shape(0);
1840+
const py::ssize_t cols = dst.get_shape(1);
1841+
if (rows == 0 || cols == 0) {
1842+
// nothing to do
1843+
return std::make_pair(sycl::event{}, sycl::event{});
1844+
}
1845+
1846+
bool is_dst_c_contig = ((dst.get_flags() & USM_ARRAY_C_CONTIGUOUS) != 0);
1847+
bool is_dst_f_contig = ((dst.get_flags() & USM_ARRAY_F_CONTIGUOUS) != 0);
1848+
if (!is_dst_c_contig && !is_dst_f_contig) {
1849+
throw py::value_error("USM array is not contiguous");
1850+
}
1851+
1852+
py::ssize_t start;
1853+
if (is_dst_c_contig) {
1854+
start = (k < 0) ? -k * cols : k;
1855+
}
1856+
else {
1857+
start = (k < 0) ? -k : k * rows;
1858+
}
1859+
1860+
py::ssize_t step;
1861+
if (dst.get_strides_raw() == nullptr) {
1862+
step = (is_dst_c_contig) ? cols + 1 : rows + 1;
1863+
}
1864+
else {
1865+
const py::ssize_t *strides = dst.get_strides_raw();
1866+
step = strides[0] + strides[1];
1867+
}
1868+
1869+
const py::ssize_t length = std::min({rows, cols, rows + k, cols - k});
1870+
const py::ssize_t end = start + step * (length - 1) + 1;
1871+
1872+
char *dst_data = dst.get_data();
1873+
sycl::event eye_event;
1874+
1875+
auto fn = eye_dispatch_vector[dst_typeid];
1876+
1877+
eye_event = fn(exec_q, static_cast<size_t>(nelem), start, end, step,
1878+
dst_data, depends);
1879+
1880+
return std::make_pair(keep_args_alive(exec_q, {dst}, {eye_event}),
1881+
eye_event);
1882+
}
1883+
17451884
// populate dispatch tables
17461885
void init_copy_and_cast_dispatch_tables(void)
17471886
{
@@ -1796,6 +1935,10 @@ void init_copy_for_reshape_dispatch_vector(void)
17961935
dvb3;
17971936
dvb3.populate_dispatch_vector(full_contig_dispatch_vector);
17981937

1938+
DispatchVectorBuilder<eye_fn_ptr_t, EyeFactory, num_types>
1939+
dvb4;
1940+
dvb4.populate_dispatch_vector(eye_dispatch_vector);
1941+
17991942
return;
18001943
}
18011944

@@ -1901,6 +2044,16 @@ PYBIND11_MODULE(_tensor_impl, m)
19012044
py::arg("fill_value"), py::arg("dst"), py::arg("sycl_queue"),
19022045
py::arg("depends") = py::list());
19032046

2047+
m.def("_eye", &eye,
2048+
"Fills input 2D contiguous usm_ndarray `dst` with "
2049+
"zeros outside of the diagonal "
2050+
"specified by "
2051+
"the diagonal index `k` "
2052+
"which is filled with ones."
2053+
"Returns a tuple of events: (ht_event, comp_event)",
2054+
py::arg("k"), py::arg("dst"), py::arg("sycl_queue"),
2055+
py::arg("depends") = py::list());
2056+
19042057
m.def("default_device_fp_type", [](sycl::queue q) -> std::string {
19052058
return get_default_device_fp_type(q.get_device());
19062059
});

0 commit comments

Comments
 (0)