@@ -45,6 +45,7 @@ template <typename srcT, typename dstT, int nd> class copy_cast_spec_kernel;
45
45
template <typename Ty> class copy_for_reshape_generic_kernel ;
46
46
template <typename Ty> class linear_sequence_step_kernel ;
47
47
template <typename Ty, typename wTy> class linear_sequence_affine_kernel ;
48
+ template <typename Ty> class eye_kernel ;
48
49
49
50
static dpctl::tensor::detail::usm_ndarray_types array_types;
50
51
@@ -1742,6 +1743,144 @@ usm_ndarray_full(py::object py_value,
1742
1743
}
1743
1744
}
1744
1745
1746
+ /* ================ Eye ================== */
1747
+
1748
+ typedef sycl::event (*eye_fn_ptr_t )(sycl::queue,
1749
+ size_t nelems, // num_elements
1750
+ py::ssize_t start,
1751
+ py::ssize_t end,
1752
+ py::ssize_t step,
1753
+ char *, // dst_data_ptr
1754
+ const std::vector<sycl::event> &);
1755
+
1756
+ static eye_fn_ptr_t eye_dispatch_vector[_ns::num_types];
1757
+
1758
+ template <typename Ty> class EyeFunctor
1759
+ {
1760
+ private:
1761
+ Ty *p = nullptr ;
1762
+ py::ssize_t start_v;
1763
+ py::ssize_t end_v;
1764
+ py::ssize_t step_v;
1765
+
1766
+ public:
1767
+ EyeFunctor (char *dst_p,
1768
+ const py::ssize_t v0,
1769
+ const py::ssize_t v1,
1770
+ const py::ssize_t dv)
1771
+ : p(reinterpret_cast <Ty *>(dst_p)), start_v(v0), end_v(v1), step_v(dv)
1772
+ {
1773
+ }
1774
+
1775
+ void operator ()(sycl::id<1 > wiid) const
1776
+ {
1777
+ Ty set_v = 0 ;
1778
+ py::ssize_t i = static_cast <py::ssize_t >(wiid.get (0 ));
1779
+ if (i >= start_v and i <= end_v) {
1780
+ if ((i - start_v) % step_v == 0 ) {
1781
+ set_v = 1 ;
1782
+ }
1783
+ }
1784
+ p[i] = set_v;
1785
+ }
1786
+ };
1787
+
1788
+ template <typename Ty>
1789
+ sycl::event eye_impl (sycl::queue exec_q,
1790
+ size_t nelems,
1791
+ const py::ssize_t start,
1792
+ const py::ssize_t end,
1793
+ const py::ssize_t step,
1794
+ char *array_data,
1795
+ const std::vector<sycl::event> &depends)
1796
+ {
1797
+ sycl::event eye_event = exec_q.submit ([&](sycl::handler &cgh) {
1798
+ cgh.depends_on (depends);
1799
+ cgh.parallel_for <eye_kernel<Ty>>(
1800
+ sycl::range<1 >{nelems},
1801
+ EyeFunctor<Ty>(array_data, start, end, step));
1802
+ });
1803
+
1804
+ return eye_event;
1805
+ }
1806
+
1807
+ template <typename fnT, typename Ty> struct EyeFactory
1808
+ {
1809
+ fnT get ()
1810
+ {
1811
+ fnT f = eye_impl<Ty>;
1812
+ return f;
1813
+ }
1814
+ };
1815
+
1816
+ std::pair<sycl::event, sycl::event>
1817
+ eye (py::ssize_t k,
1818
+ dpctl::tensor::usm_ndarray dst,
1819
+ sycl::queue exec_q,
1820
+ const std::vector<sycl::event> &depends = {})
1821
+ {
1822
+ // dst must be 2D
1823
+
1824
+ if (dst.get_ndim () != 2 ) {
1825
+ throw py::value_error (
1826
+ " usm_ndarray_eye: Expecting 2D array to populate" );
1827
+ }
1828
+
1829
+ sycl::queue dst_q = dst.get_queue ();
1830
+ if (dst_q != exec_q && dst_q.get_context () != exec_q.get_context ()) {
1831
+ throw py::value_error (
1832
+ " Execution queue context is not the same as allocation context" );
1833
+ }
1834
+
1835
+ int dst_typenum = dst.get_typenum ();
1836
+ int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
1837
+
1838
+ const py::ssize_t nelem = dst.get_size ();
1839
+ const py::ssize_t rows = dst.get_shape (0 );
1840
+ const py::ssize_t cols = dst.get_shape (1 );
1841
+ if (rows == 0 || cols == 0 ) {
1842
+ // nothing to do
1843
+ return std::make_pair (sycl::event{}, sycl::event{});
1844
+ }
1845
+
1846
+ bool is_dst_c_contig = ((dst.get_flags () & USM_ARRAY_C_CONTIGUOUS) != 0 );
1847
+ bool is_dst_f_contig = ((dst.get_flags () & USM_ARRAY_F_CONTIGUOUS) != 0 );
1848
+ if (!is_dst_c_contig && !is_dst_f_contig) {
1849
+ throw py::value_error (" USM array is not contiguous" );
1850
+ }
1851
+
1852
+ py::ssize_t start;
1853
+ if (is_dst_c_contig) {
1854
+ start = (k < 0 ) ? -k * cols : k;
1855
+ }
1856
+ else {
1857
+ start = (k < 0 ) ? -k : k * rows;
1858
+ }
1859
+
1860
+ py::ssize_t step;
1861
+ if (dst.get_strides_raw () == nullptr ) {
1862
+ step = (is_dst_c_contig) ? cols + 1 : rows + 1 ;
1863
+ }
1864
+ else {
1865
+ const py::ssize_t *strides = dst.get_strides_raw ();
1866
+ step = strides[0 ] + strides[1 ];
1867
+ }
1868
+
1869
+ const py::ssize_t length = std::min ({rows, cols, rows + k, cols - k});
1870
+ const py::ssize_t end = start + step * (length - 1 ) + 1 ;
1871
+
1872
+ char *dst_data = dst.get_data ();
1873
+ sycl::event eye_event;
1874
+
1875
+ auto fn = eye_dispatch_vector[dst_typeid];
1876
+
1877
+ eye_event = fn (exec_q, static_cast <size_t >(nelem), start, end, step,
1878
+ dst_data, depends);
1879
+
1880
+ return std::make_pair (keep_args_alive (exec_q, {dst}, {eye_event}),
1881
+ eye_event);
1882
+ }
1883
+
1745
1884
// populate dispatch tables
1746
1885
void init_copy_and_cast_dispatch_tables (void )
1747
1886
{
@@ -1796,6 +1935,10 @@ void init_copy_for_reshape_dispatch_vector(void)
1796
1935
dvb3;
1797
1936
dvb3.populate_dispatch_vector (full_contig_dispatch_vector);
1798
1937
1938
+ DispatchVectorBuilder<eye_fn_ptr_t , EyeFactory, num_types>
1939
+ dvb4;
1940
+ dvb4.populate_dispatch_vector (eye_dispatch_vector);
1941
+
1799
1942
return ;
1800
1943
}
1801
1944
@@ -1901,6 +2044,16 @@ PYBIND11_MODULE(_tensor_impl, m)
1901
2044
py::arg (" fill_value" ), py::arg (" dst" ), py::arg (" sycl_queue" ),
1902
2045
py::arg (" depends" ) = py::list ());
1903
2046
2047
+ m.def (" _eye" , &eye,
2048
+ " Fills input 2D contiguous usm_ndarray `dst` with "
2049
+ " zeros outside of the diagonal "
2050
+ " specified by "
2051
+ " the diagonal index `k` "
2052
+ " which is filled with ones."
2053
+ " Returns a tuple of events: (ht_event, comp_event)" ,
2054
+ py::arg (" k" ), py::arg (" dst" ), py::arg (" sycl_queue" ),
2055
+ py::arg (" depends" ) = py::list ());
2056
+
1904
2057
m.def (" default_device_fp_type" , [](sycl::queue q) -> std::string {
1905
2058
return get_default_device_fp_type (q.get_device ());
1906
2059
});
0 commit comments