Skip to content

Commit c5b0639

Browse files
committed
Update lapack extension to use init_dispatch_vector() and init_dispatch_table() from common utils
1 parent efa5c45 commit c5b0639

22 files changed

+136
-118
lines changed

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ endif()
7676

7777
set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON)
7878

79-
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
80-
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
79+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
80+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)
8181

8282
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR})
8383
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

dpnp/backend/extensions/lapack/evd_common_utils.hpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,11 @@
3030
// dpctl tensor headers
3131
#include "utils/memory_overlap.hpp"
3232
#include "utils/output_validation.hpp"
33-
#include "utils/type_dispatch.hpp"
3433

3534
namespace dpnp::extensions::lapack::evd
3635
{
37-
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3836
namespace py = pybind11;
3937

40-
template <typename dispatchT,
41-
template <typename fnT, typename T, typename RealT>
42-
typename factoryT>
43-
void init_evd_dispatch_table(
44-
dispatchT evd_dispatch_table[][dpctl_td_ns::num_types])
45-
{
46-
dpctl_td_ns::DispatchTableBuilder<dispatchT, factoryT,
47-
dpctl_td_ns::num_types>
48-
contig;
49-
contig.populate_dispatch_table(evd_dispatch_table);
50-
}
51-
5238
inline void common_evd_checks(sycl::queue &exec_q,
5339
const dpctl::tensor::usm_ndarray &eig_vecs,
5440
const dpctl::tensor::usm_ndarray &eig_vals,

dpnp/backend/extensions/lapack/geqrf.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
#include <pybind11/pybind11.h>
2929

30+
// utils extension header
31+
#include "ext/common.hpp"
32+
3033
// dpctl tensor headers
3134
#include "utils/memory_overlap.hpp"
3235
#include "utils/sycl_alloc_utils.hpp"
@@ -35,14 +38,14 @@
3538
#include "geqrf.hpp"
3639
#include "types_matrix.hpp"
3740

38-
#include "dpnp_utils.hpp"
39-
4041
namespace dpnp::extensions::lapack
4142
{
4243
namespace mkl_lapack = oneapi::mkl::lapack;
4344
namespace py = pybind11;
4445
namespace type_utils = dpctl::tensor::type_utils;
4546

47+
using ext::common::init_dispatch_vector;
48+
4649
typedef sycl::event (*geqrf_impl_fn_ptr_t)(sycl::queue &,
4750
const std::int64_t,
4851
const std::int64_t,
@@ -250,9 +253,7 @@ struct GeqrfContigFactory
250253

251254
void init_geqrf_dispatch_vector(void)
252255
{
253-
dpctl_td_ns::DispatchVectorBuilder<geqrf_impl_fn_ptr_t, GeqrfContigFactory,
254-
dpctl_td_ns::num_types>
255-
contig;
256-
contig.populate_dispatch_vector(geqrf_dispatch_vector);
256+
init_dispatch_vector<geqrf_impl_fn_ptr_t, GeqrfContigFactory>(
257+
geqrf_dispatch_vector);
257258
}
258259
} // namespace dpnp::extensions::lapack

dpnp/backend/extensions/lapack/geqrf_batch.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
#include <pybind11/pybind11.h>
2929

30+
// utils extension header
31+
#include "ext/common.hpp"
32+
3033
// dpctl tensor headers
3134
#include "utils/memory_overlap.hpp"
3235
#include "utils/sycl_alloc_utils.hpp"
@@ -35,14 +38,14 @@
3538
#include "geqrf.hpp"
3639
#include "types_matrix.hpp"
3740

38-
#include "dpnp_utils.hpp"
39-
4041
namespace dpnp::extensions::lapack
4142
{
4243
namespace mkl_lapack = oneapi::mkl::lapack;
4344
namespace py = pybind11;
4445
namespace type_utils = dpctl::tensor::type_utils;
4546

47+
using ext::common::init_dispatch_vector;
48+
4649
typedef sycl::event (*geqrf_batch_impl_fn_ptr_t)(
4750
sycl::queue &,
4851
std::int64_t,
@@ -260,10 +263,7 @@ struct GeqrfBatchContigFactory
260263

261264
void init_geqrf_batch_dispatch_vector(void)
262265
{
263-
dpctl_td_ns::DispatchVectorBuilder<geqrf_batch_impl_fn_ptr_t,
264-
GeqrfBatchContigFactory,
265-
dpctl_td_ns::num_types>
266-
contig;
267-
contig.populate_dispatch_vector(geqrf_batch_dispatch_vector);
266+
init_dispatch_vector<geqrf_batch_impl_fn_ptr_t, GeqrfBatchContigFactory>(
267+
geqrf_batch_dispatch_vector);
268268
}
269269
} // namespace dpnp::extensions::lapack

dpnp/backend/extensions/lapack/gesv.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929
#include <pybind11/pybind11.h>
3030

31+
// utils extension header
32+
#include "ext/common.hpp"
33+
3134
// dpctl tensor headers
3235
#include "utils/type_utils.hpp"
3336

@@ -43,6 +46,7 @@ namespace py = pybind11;
4346
namespace type_utils = dpctl::tensor::type_utils;
4447

4548
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
49+
using ext::common::init_dispatch_vector;
4650

4751
typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue &,
4852
const std::int64_t,
@@ -284,9 +288,7 @@ struct GesvContigFactory
284288

285289
void init_gesv_dispatch_vector(void)
286290
{
287-
dpctl_td_ns::DispatchVectorBuilder<gesv_impl_fn_ptr_t, GesvContigFactory,
288-
dpctl_td_ns::num_types>
289-
contig;
290-
contig.populate_dispatch_vector(gesv_dispatch_vector);
291+
init_dispatch_vector<gesv_impl_fn_ptr_t, GesvContigFactory>(
292+
gesv_dispatch_vector);
291293
}
292294
} // namespace dpnp::extensions::lapack

dpnp/backend/extensions/lapack/gesv_batch.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929
#include <pybind11/pybind11.h>
3030

31+
// utils extension header
32+
#include "ext/common.hpp"
33+
3134
// dpctl tensor headers
3235
#include "utils/type_utils.hpp"
3336

@@ -43,6 +46,7 @@ namespace py = pybind11;
4346
namespace type_utils = dpctl::tensor::type_utils;
4447

4548
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
49+
using ext::common::init_dispatch_vector;
4650

4751
typedef sycl::event (*gesv_batch_impl_fn_ptr_t)(
4852
sycl::queue &,
@@ -425,10 +429,7 @@ struct GesvBatchContigFactory
425429

426430
void init_gesv_batch_dispatch_vector(void)
427431
{
428-
dpctl_td_ns::DispatchVectorBuilder<gesv_batch_impl_fn_ptr_t,
429-
GesvBatchContigFactory,
430-
dpctl_td_ns::num_types>
431-
contig;
432-
contig.populate_dispatch_vector(gesv_batch_dispatch_vector);
432+
init_dispatch_vector<gesv_batch_impl_fn_ptr_t, GesvBatchContigFactory>(
433+
gesv_batch_dispatch_vector);
433434
}
434435
} // namespace dpnp::extensions::lapack

dpnp/backend/extensions/lapack/gesvd.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
#include <pybind11/pybind11.h>
2929

30+
// utils extension header
31+
#include "ext/common.hpp"
32+
3033
// dpctl tensor headers
3134
#include "utils/type_utils.hpp"
3235

@@ -41,6 +44,8 @@ namespace mkl_lapack = oneapi::mkl::lapack;
4144
namespace py = pybind11;
4245
namespace type_utils = dpctl::tensor::type_utils;
4346

47+
using ext::common::init_dispatch_table;
48+
4449
typedef sycl::event (*gesvd_impl_fn_ptr_t)(sycl::queue &,
4550
const oneapi::mkl::jobsvd,
4651
const oneapi::mkl::jobsvd,
@@ -227,9 +232,7 @@ struct GesvdContigFactory
227232

228233
void init_gesvd_dispatch_table(void)
229234
{
230-
dpctl_td_ns::DispatchTableBuilder<gesvd_impl_fn_ptr_t, GesvdContigFactory,
231-
dpctl_td_ns::num_types>
232-
contig;
233-
contig.populate_dispatch_table(gesvd_dispatch_table);
235+
init_dispatch_table<gesvd_impl_fn_ptr_t, GesvdContigFactory>(
236+
gesvd_dispatch_table);
234237
}
235238
} // namespace dpnp::extensions::lapack

dpnp/backend/extensions/lapack/gesvd_batch.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
#include <pybind11/pybind11.h>
2929

30+
// utils extension header
31+
#include "ext/common.hpp"
32+
3033
// dpctl tensor headers
3134
#include "utils/type_utils.hpp"
3235

@@ -41,6 +44,8 @@ namespace mkl_lapack = oneapi::mkl::lapack;
4144
namespace py = pybind11;
4245
namespace type_utils = dpctl::tensor::type_utils;
4346

47+
using ext::common::init_dispatch_table;
48+
4449
typedef sycl::event (*gesvd_batch_impl_fn_ptr_t)(
4550
sycl::queue &,
4651
const oneapi::mkl::jobsvd,
@@ -296,10 +301,7 @@ struct GesvdBatchContigFactory
296301

297302
void init_gesvd_batch_dispatch_table(void)
298303
{
299-
dpctl_td_ns::DispatchTableBuilder<gesvd_batch_impl_fn_ptr_t,
300-
GesvdBatchContigFactory,
301-
dpctl_td_ns::num_types>
302-
contig;
303-
contig.populate_dispatch_table(gesvd_batch_dispatch_table);
304+
init_dispatch_table<gesvd_batch_impl_fn_ptr_t, GesvdBatchContigFactory>(
305+
gesvd_batch_dispatch_table);
304306
}
305307
} // namespace dpnp::extensions::lapack

dpnp/backend/extensions/lapack/getrf.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
#include <pybind11/pybind11.h>
2929

30+
// utils extension header
31+
#include "ext/common.hpp"
32+
3033
// dpctl tensor headers
3134
#include "utils/memory_overlap.hpp"
3235
#include "utils/sycl_alloc_utils.hpp"
@@ -35,14 +38,14 @@
3538
#include "getrf.hpp"
3639
#include "types_matrix.hpp"
3740

38-
#include "dpnp_utils.hpp"
39-
4041
namespace dpnp::extensions::lapack
4142
{
4243
namespace mkl_lapack = oneapi::mkl::lapack;
4344
namespace py = pybind11;
4445
namespace type_utils = dpctl::tensor::type_utils;
4546

47+
using ext::common::init_dispatch_vector;
48+
4649
typedef sycl::event (*getrf_impl_fn_ptr_t)(sycl::queue &,
4750
const std::int64_t,
4851
const std::int64_t,
@@ -262,9 +265,7 @@ struct GetrfContigFactory
262265

263266
void init_getrf_dispatch_vector(void)
264267
{
265-
dpctl_td_ns::DispatchVectorBuilder<getrf_impl_fn_ptr_t, GetrfContigFactory,
266-
dpctl_td_ns::num_types>
267-
contig;
268-
contig.populate_dispatch_vector(getrf_dispatch_vector);
268+
init_dispatch_vector<getrf_impl_fn_ptr_t, GetrfContigFactory>(
269+
getrf_dispatch_vector);
269270
}
270271
} // namespace dpnp::extensions::lapack

dpnp/backend/extensions/lapack/getrf_batch.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
#include <pybind11/pybind11.h>
2929

30+
// utils extension header
31+
#include "ext/common.hpp"
32+
3033
// dpctl tensor headers
3134
#include "utils/memory_overlap.hpp"
3235
#include "utils/sycl_alloc_utils.hpp"
@@ -35,14 +38,14 @@
3538
#include "getrf.hpp"
3639
#include "types_matrix.hpp"
3740

38-
#include "dpnp_utils.hpp"
39-
4041
namespace dpnp::extensions::lapack
4142
{
4243
namespace mkl_lapack = oneapi::mkl::lapack;
4344
namespace py = pybind11;
4445
namespace type_utils = dpctl::tensor::type_utils;
4546

47+
using ext::common::init_dispatch_vector;
48+
4649
typedef sycl::event (*getrf_batch_impl_fn_ptr_t)(
4750
sycl::queue &,
4851
std::int64_t,
@@ -295,10 +298,7 @@ struct GetrfBatchContigFactory
295298

296299
void init_getrf_batch_dispatch_vector(void)
297300
{
298-
dpctl_td_ns::DispatchVectorBuilder<getrf_batch_impl_fn_ptr_t,
299-
GetrfBatchContigFactory,
300-
dpctl_td_ns::num_types>
301-
contig;
302-
contig.populate_dispatch_vector(getrf_batch_dispatch_vector);
301+
init_dispatch_vector<getrf_batch_impl_fn_ptr_t, GetrfBatchContigFactory>(
302+
getrf_batch_dispatch_vector);
303303
}
304304
} // namespace dpnp::extensions::lapack

0 commit comments

Comments
 (0)