Skip to content

Commit e4440ad

Browse files
committed
Update window extension to use init_dispatch_vector() from common utils
1 parent 1c00e67 commit e4440ad

File tree

4 files changed

+25
-26
lines changed

4 files changed

+25
-26
lines changed

dpnp/backend/extensions/window/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ endif()
5757

5858
set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON)
5959

60-
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
61-
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
60+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
61+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)
6262

6363
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR})
6464
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

dpnp/backend/extensions/window/common.hpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
#include <sycl/sycl.hpp>
3131

3232
#include "dpctl4pybind11.hpp"
33+
34+
// dpctl tensor headers
3335
#include "utils/output_validation.hpp"
3436
#include "utils/type_dispatch.hpp"
3537
#include "utils/type_utils.hpp"
@@ -127,18 +129,4 @@ inline std::pair<sycl::event, sycl::event>
127129

128130
return std::make_pair(args_ev, window_ev);
129131
}
130-
131-
template <typename funcPtrT,
132-
template <typename fnT, typename T>
133-
typename factoryT>
134-
void init_window_dispatch_vectors(funcPtrT window_dispatch_vector[])
135-
{
136-
dpctl_td_ns::DispatchVectorBuilder<funcPtrT, factoryT,
137-
dpctl_td_ns::num_types>
138-
contig;
139-
contig.populate_dispatch_vector(window_dispatch_vector);
140-
141-
return;
142-
}
143-
144132
} // namespace dpnp::extensions::window

dpnp/backend/extensions/window/kaiser.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,24 @@
2626
#include "kaiser.hpp"
2727
#include "common.hpp"
2828

29+
// utils extension header
30+
#include "ext/common.hpp"
31+
32+
// dpctl tensor headers
2933
#include "utils/output_validation.hpp"
3034
#include "utils/type_dispatch.hpp"
3135
#include "utils/type_utils.hpp"
3236

3337
#include <sycl/sycl.hpp>
3438

35-
#include "../kernels/elementwise_functions/i0.hpp"
39+
#include "kernels/elementwise_functions/i0.hpp"
3640

3741
namespace dpnp::extensions::window
3842
{
3943
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
4044

45+
using ext::common::init_dispatch_vector;
46+
4147
typedef sycl::event (*kaiser_fn_ptr_t)(sycl::queue &,
4248
char *,
4349
const std::size_t,
@@ -132,7 +138,7 @@ std::pair<sycl::event, sycl::event>
132138

133139
void init_kaiser_dispatch_vectors()
134140
{
135-
init_window_dispatch_vectors<kaiser_fn_ptr_t, KaiserFactory>(
141+
init_dispatch_vector<kaiser_fn_ptr_t, KaiserFactory>(
136142
kaiser_dispatch_vector);
137143
}
138144

dpnp/backend/extensions/window/window_py.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,13 @@
3737
#include "hanning.hpp"
3838
#include "kaiser.hpp"
3939

40+
// utils extension header
41+
#include "ext/common.hpp"
42+
4043
namespace window_ns = dpnp::extensions::window;
4144
namespace py = pybind11;
45+
46+
using ext::common::init_dispatch_vector;
4247
using window_ns::window_fn_ptr_t;
4348

4449
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
@@ -54,8 +59,8 @@ PYBIND11_MODULE(_window_impl, m)
5459
using event_vecT = std::vector<sycl::event>;
5560

5661
{
57-
window_ns::init_window_dispatch_vectors<
58-
window_ns::window_fn_ptr_t, window_ns::kernels::BartlettFactory>(
62+
init_dispatch_vector<window_ns::window_fn_ptr_t,
63+
window_ns::kernels::BartlettFactory>(
5964
bartlett_dispatch_vector);
6065

6166
auto bartlett_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
@@ -70,8 +75,8 @@ PYBIND11_MODULE(_window_impl, m)
7075
}
7176

7277
{
73-
window_ns::init_window_dispatch_vectors<
74-
window_ns::window_fn_ptr_t, window_ns::kernels::BlackmanFactory>(
78+
init_dispatch_vector<window_ns::window_fn_ptr_t,
79+
window_ns::kernels::BlackmanFactory>(
7580
blackman_dispatch_vector);
7681

7782
auto blackman_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
@@ -86,8 +91,8 @@ PYBIND11_MODULE(_window_impl, m)
8691
}
8792

8893
{
89-
window_ns::init_window_dispatch_vectors<
90-
window_ns::window_fn_ptr_t, window_ns::kernels::HammingFactory>(
94+
init_dispatch_vector<window_ns::window_fn_ptr_t,
95+
window_ns::kernels::HammingFactory>(
9196
hamming_dispatch_vector);
9297

9398
auto hamming_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
@@ -102,8 +107,8 @@ PYBIND11_MODULE(_window_impl, m)
102107
}
103108

104109
{
105-
window_ns::init_window_dispatch_vectors<
106-
window_ns::window_fn_ptr_t, window_ns::kernels::HanningFactory>(
110+
init_dispatch_vector<window_ns::window_fn_ptr_t,
111+
window_ns::kernels::HanningFactory>(
107112
hanning_dispatch_vector);
108113

109114
auto hanning_pyapi = [&](sycl::queue &exec_q, const arrayT &result,

0 commit comments

Comments
 (0)