File tree Expand file tree Collapse file tree 22 files changed +136
-118
lines changed
dpnp/backend/extensions/lapack Expand file tree Collapse file tree 22 files changed +136
-118
lines changed Original file line number Diff line number Diff line change @@ -76,8 +76,8 @@ endif()
7676
7777set_target_properties (${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON )
7878
79- target_include_directories (${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} /../../include )
80- target_include_directories (${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} /../../src )
79+ target_include_directories (${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} /../../)
80+ target_include_directories (${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} /../common )
8181
8282target_include_directories (${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR} )
8383target_include_directories (${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR} )
Original file line number Diff line number Diff line change 3030// dpctl tensor headers
3131#include " utils/memory_overlap.hpp"
3232#include " utils/output_validation.hpp"
33- #include " utils/type_dispatch.hpp"
3433
3534namespace dpnp ::extensions::lapack::evd
3635{
37- namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3836namespace py = pybind11;
3937
40- template <typename dispatchT,
41- template <typename fnT, typename T, typename RealT>
42- typename factoryT>
43- void init_evd_dispatch_table (
44- dispatchT evd_dispatch_table[][dpctl_td_ns::num_types])
45- {
46- dpctl_td_ns::DispatchTableBuilder<dispatchT, factoryT,
47- dpctl_td_ns::num_types>
48- contig;
49- contig.populate_dispatch_table (evd_dispatch_table);
50- }
51-
5238inline void common_evd_checks (sycl::queue &exec_q,
5339 const dpctl::tensor::usm_ndarray &eig_vecs,
5440 const dpctl::tensor::usm_ndarray &eig_vals,
Original file line number Diff line number Diff line change 2727
2828#include < pybind11/pybind11.h>
2929
30+ // utils extension header
31+ #include " ext/common.hpp"
32+
3033// dpctl tensor headers
3134#include " utils/memory_overlap.hpp"
3235#include " utils/sycl_alloc_utils.hpp"
3538#include " geqrf.hpp"
3639#include " types_matrix.hpp"
3740
38- #include " dpnp_utils.hpp"
39-
4041namespace dpnp ::extensions::lapack
4142{
4243namespace mkl_lapack = oneapi::mkl::lapack;
4344namespace py = pybind11;
4445namespace type_utils = dpctl::tensor::type_utils;
4546
47+ using ext::common::init_dispatch_vector;
48+
4649typedef sycl::event (*geqrf_impl_fn_ptr_t )(sycl::queue &,
4750 const std::int64_t ,
4851 const std::int64_t ,
@@ -250,9 +253,7 @@ struct GeqrfContigFactory
250253
251254void init_geqrf_dispatch_vector (void )
252255{
253- dpctl_td_ns::DispatchVectorBuilder<geqrf_impl_fn_ptr_t , GeqrfContigFactory,
254- dpctl_td_ns::num_types>
255- contig;
256- contig.populate_dispatch_vector (geqrf_dispatch_vector);
256+ init_dispatch_vector<geqrf_impl_fn_ptr_t , GeqrfContigFactory>(
257+ geqrf_dispatch_vector);
257258}
258259} // namespace dpnp::extensions::lapack
Original file line number Diff line number Diff line change 2727
2828#include < pybind11/pybind11.h>
2929
30+ // utils extension header
31+ #include " ext/common.hpp"
32+
3033// dpctl tensor headers
3134#include " utils/memory_overlap.hpp"
3235#include " utils/sycl_alloc_utils.hpp"
3538#include " geqrf.hpp"
3639#include " types_matrix.hpp"
3740
38- #include " dpnp_utils.hpp"
39-
4041namespace dpnp ::extensions::lapack
4142{
4243namespace mkl_lapack = oneapi::mkl::lapack;
4344namespace py = pybind11;
4445namespace type_utils = dpctl::tensor::type_utils;
4546
47+ using ext::common::init_dispatch_vector;
48+
4649typedef sycl::event (*geqrf_batch_impl_fn_ptr_t )(
4750 sycl::queue &,
4851 std::int64_t ,
@@ -260,10 +263,7 @@ struct GeqrfBatchContigFactory
260263
261264void init_geqrf_batch_dispatch_vector (void )
262265{
263- dpctl_td_ns::DispatchVectorBuilder<geqrf_batch_impl_fn_ptr_t ,
264- GeqrfBatchContigFactory,
265- dpctl_td_ns::num_types>
266- contig;
267- contig.populate_dispatch_vector (geqrf_batch_dispatch_vector);
266+ init_dispatch_vector<geqrf_batch_impl_fn_ptr_t , GeqrfBatchContigFactory>(
267+ geqrf_batch_dispatch_vector);
268268}
269269} // namespace dpnp::extensions::lapack
Original file line number Diff line number Diff line change 2828
2929#include < pybind11/pybind11.h>
3030
31+ // utils extension header
32+ #include " ext/common.hpp"
33+
3134// dpctl tensor headers
3235#include " utils/type_utils.hpp"
3336
@@ -43,6 +46,7 @@ namespace py = pybind11;
4346namespace type_utils = dpctl::tensor::type_utils;
4447
4548using dpctl::tensor::alloc_utils::sycl_free_noexcept;
49+ using ext::common::init_dispatch_vector;
4650
4751typedef sycl::event (*gesv_impl_fn_ptr_t )(sycl::queue &,
4852 const std::int64_t ,
@@ -284,9 +288,7 @@ struct GesvContigFactory
284288
285289void init_gesv_dispatch_vector (void )
286290{
287- dpctl_td_ns::DispatchVectorBuilder<gesv_impl_fn_ptr_t , GesvContigFactory,
288- dpctl_td_ns::num_types>
289- contig;
290- contig.populate_dispatch_vector (gesv_dispatch_vector);
291+ init_dispatch_vector<gesv_impl_fn_ptr_t , GesvContigFactory>(
292+ gesv_dispatch_vector);
291293}
292294} // namespace dpnp::extensions::lapack
Original file line number Diff line number Diff line change 2828
2929#include < pybind11/pybind11.h>
3030
31+ // utils extension header
32+ #include " ext/common.hpp"
33+
3134// dpctl tensor headers
3235#include " utils/type_utils.hpp"
3336
@@ -43,6 +46,7 @@ namespace py = pybind11;
4346namespace type_utils = dpctl::tensor::type_utils;
4447
4548using dpctl::tensor::alloc_utils::sycl_free_noexcept;
49+ using ext::common::init_dispatch_vector;
4650
4751typedef sycl::event (*gesv_batch_impl_fn_ptr_t )(
4852 sycl::queue &,
@@ -425,10 +429,7 @@ struct GesvBatchContigFactory
425429
426430void init_gesv_batch_dispatch_vector (void )
427431{
428- dpctl_td_ns::DispatchVectorBuilder<gesv_batch_impl_fn_ptr_t ,
429- GesvBatchContigFactory,
430- dpctl_td_ns::num_types>
431- contig;
432- contig.populate_dispatch_vector (gesv_batch_dispatch_vector);
432+ init_dispatch_vector<gesv_batch_impl_fn_ptr_t , GesvBatchContigFactory>(
433+ gesv_batch_dispatch_vector);
433434}
434435} // namespace dpnp::extensions::lapack
Original file line number Diff line number Diff line change 2727
2828#include < pybind11/pybind11.h>
2929
30+ // utils extension header
31+ #include " ext/common.hpp"
32+
3033// dpctl tensor headers
3134#include " utils/type_utils.hpp"
3235
@@ -41,6 +44,8 @@ namespace mkl_lapack = oneapi::mkl::lapack;
4144namespace py = pybind11;
4245namespace type_utils = dpctl::tensor::type_utils;
4346
47+ using ext::common::init_dispatch_table;
48+
4449typedef sycl::event (*gesvd_impl_fn_ptr_t )(sycl::queue &,
4550 const oneapi::mkl::jobsvd,
4651 const oneapi::mkl::jobsvd,
@@ -227,9 +232,7 @@ struct GesvdContigFactory
227232
228233void init_gesvd_dispatch_table (void )
229234{
230- dpctl_td_ns::DispatchTableBuilder<gesvd_impl_fn_ptr_t , GesvdContigFactory,
231- dpctl_td_ns::num_types>
232- contig;
233- contig.populate_dispatch_table (gesvd_dispatch_table);
235+ init_dispatch_table<gesvd_impl_fn_ptr_t , GesvdContigFactory>(
236+ gesvd_dispatch_table);
234237}
235238} // namespace dpnp::extensions::lapack
Original file line number Diff line number Diff line change 2727
2828#include < pybind11/pybind11.h>
2929
30+ // utils extension header
31+ #include " ext/common.hpp"
32+
3033// dpctl tensor headers
3134#include " utils/type_utils.hpp"
3235
@@ -41,6 +44,8 @@ namespace mkl_lapack = oneapi::mkl::lapack;
4144namespace py = pybind11;
4245namespace type_utils = dpctl::tensor::type_utils;
4346
47+ using ext::common::init_dispatch_table;
48+
4449typedef sycl::event (*gesvd_batch_impl_fn_ptr_t )(
4550 sycl::queue &,
4651 const oneapi::mkl::jobsvd,
@@ -296,10 +301,7 @@ struct GesvdBatchContigFactory
296301
297302void init_gesvd_batch_dispatch_table (void )
298303{
299- dpctl_td_ns::DispatchTableBuilder<gesvd_batch_impl_fn_ptr_t ,
300- GesvdBatchContigFactory,
301- dpctl_td_ns::num_types>
302- contig;
303- contig.populate_dispatch_table (gesvd_batch_dispatch_table);
304+ init_dispatch_table<gesvd_batch_impl_fn_ptr_t , GesvdBatchContigFactory>(
305+ gesvd_batch_dispatch_table);
304306}
305307} // namespace dpnp::extensions::lapack
Original file line number Diff line number Diff line change 2727
2828#include < pybind11/pybind11.h>
2929
30+ // utils extension header
31+ #include " ext/common.hpp"
32+
3033// dpctl tensor headers
3134#include " utils/memory_overlap.hpp"
3235#include " utils/sycl_alloc_utils.hpp"
3538#include " getrf.hpp"
3639#include " types_matrix.hpp"
3740
38- #include " dpnp_utils.hpp"
39-
4041namespace dpnp ::extensions::lapack
4142{
4243namespace mkl_lapack = oneapi::mkl::lapack;
4344namespace py = pybind11;
4445namespace type_utils = dpctl::tensor::type_utils;
4546
47+ using ext::common::init_dispatch_vector;
48+
4649typedef sycl::event (*getrf_impl_fn_ptr_t )(sycl::queue &,
4750 const std::int64_t ,
4851 const std::int64_t ,
@@ -262,9 +265,7 @@ struct GetrfContigFactory
262265
263266void init_getrf_dispatch_vector (void )
264267{
265- dpctl_td_ns::DispatchVectorBuilder<getrf_impl_fn_ptr_t , GetrfContigFactory,
266- dpctl_td_ns::num_types>
267- contig;
268- contig.populate_dispatch_vector (getrf_dispatch_vector);
268+ init_dispatch_vector<getrf_impl_fn_ptr_t , GetrfContigFactory>(
269+ getrf_dispatch_vector);
269270}
270271} // namespace dpnp::extensions::lapack
Original file line number Diff line number Diff line change 2727
2828#include < pybind11/pybind11.h>
2929
30+ // utils extension header
31+ #include " ext/common.hpp"
32+
3033// dpctl tensor headers
3134#include " utils/memory_overlap.hpp"
3235#include " utils/sycl_alloc_utils.hpp"
3538#include " getrf.hpp"
3639#include " types_matrix.hpp"
3740
38- #include " dpnp_utils.hpp"
39-
4041namespace dpnp ::extensions::lapack
4142{
4243namespace mkl_lapack = oneapi::mkl::lapack;
4344namespace py = pybind11;
4445namespace type_utils = dpctl::tensor::type_utils;
4546
47+ using ext::common::init_dispatch_vector;
48+
4649typedef sycl::event (*getrf_batch_impl_fn_ptr_t )(
4750 sycl::queue &,
4851 std::int64_t ,
@@ -295,10 +298,7 @@ struct GetrfBatchContigFactory
295298
296299void init_getrf_batch_dispatch_vector (void )
297300{
298- dpctl_td_ns::DispatchVectorBuilder<getrf_batch_impl_fn_ptr_t ,
299- GetrfBatchContigFactory,
300- dpctl_td_ns::num_types>
301- contig;
302- contig.populate_dispatch_vector (getrf_batch_dispatch_vector);
301+ init_dispatch_vector<getrf_batch_impl_fn_ptr_t , GetrfBatchContigFactory>(
302+ getrf_batch_dispatch_vector);
303303}
304304} // namespace dpnp::extensions::lapack
You can’t perform that action at this time.
0 commit comments