Skip to content

Commit c8a2822

Browse files
author
Vahid Tavanashad
committed
implement dpnp.kaiser
1 parent e0b7932 commit c8a2822

File tree

9 files changed

+439
-14
lines changed

9 files changed

+439
-14
lines changed

dpnp/backend/extensions/window/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
set(python_module_name _window_impl)
2828
set(_module_src
29+
${CMAKE_CURRENT_SOURCE_DIR}/kaiser.cpp
2930
${CMAKE_CURRENT_SOURCE_DIR}/window_py.cpp
3031
)
3132

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include "kaiser.hpp"
27+
#include "utils/output_validation.hpp"
28+
#include "utils/type_dispatch.hpp"
29+
#include "utils/type_utils.hpp"
30+
#include <sycl/sycl.hpp>
31+
32+
/**
33+
* Version of SYCL DPC++ 2025.1 compiler where an issue with
34+
* sycl::ext::intel::math::cyl_bessel_i0(x) is fully resolved.
35+
*/
36+
#ifndef __SYCL_COMPILER_BESSEL_I0_SUPPORT
37+
#define __SYCL_COMPILER_BESSEL_I0_SUPPORT 20241208L
38+
#endif
39+
40+
#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT
41+
#include <sycl/ext/intel/math.hpp>
42+
#endif
43+
44+
#include "../kernels/elementwise_functions/i0.hpp"
45+
46+
namespace dpnp::extensions::window
47+
{
48+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
49+
50+
typedef sycl::event (*kaiser_fn_ptr_t)(sycl::queue &,
51+
char *,
52+
const std::size_t,
53+
const float,
54+
const std::vector<sycl::event> &);
55+
56+
static kaiser_fn_ptr_t kaiser_dispatch_vector[dpctl_td_ns::num_types];
57+
58+
template <typename T>
59+
class KaiserFunctor
60+
{
61+
private:
62+
T *data = nullptr;
63+
const std::size_t N;
64+
const float beta;
65+
66+
public:
67+
KaiserFunctor(T *data, const std::size_t N, const float beta)
68+
: data(data), N(N), beta(beta)
69+
{
70+
}
71+
72+
void operator()(sycl::id<1> id) const
73+
{
74+
#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT
75+
using sycl::ext::intel::math::cyl_bessel_i0;
76+
#else
77+
using dpnp::kernels::i0::impl::cyl_bessel_i0;
78+
#endif
79+
80+
const auto i = id.get(0);
81+
const T alpha = (N - 1) / T(2);
82+
const T tmp = (i - alpha) / alpha;
83+
data[i] = cyl_bessel_i0(beta * sycl::sqrt(1 - tmp * tmp)) /
84+
cyl_bessel_i0(beta);
85+
}
86+
};
87+
88+
template <typename T, template <typename> class Functor>
89+
sycl::event kaiser_impl(sycl::queue &q,
90+
char *result,
91+
const std::size_t nelems,
92+
const float beta,
93+
const std::vector<sycl::event> &depends)
94+
{
95+
dpctl::tensor::type_utils::validate_type_for_device<T>(q);
96+
97+
T *res = reinterpret_cast<T *>(result);
98+
99+
sycl::event kaiser_ev = q.submit([&](sycl::handler &cgh) {
100+
cgh.depends_on(depends);
101+
102+
using KaiserKernel = Functor<T>;
103+
cgh.parallel_for<KaiserKernel>(sycl::range<1>(nelems),
104+
KaiserKernel(res, nelems, beta));
105+
});
106+
107+
return kaiser_ev;
108+
}
109+
110+
template <typename fnT, typename T>
111+
struct KaiserFactory
112+
{
113+
fnT get()
114+
{
115+
if constexpr (std::is_floating_point_v<T>) {
116+
return kaiser_impl<T, KaiserFunctor>;
117+
}
118+
else {
119+
return nullptr;
120+
}
121+
}
122+
};
123+
124+
std::pair<sycl::event, sycl::event>
125+
py_kaiser(sycl::queue &exec_q,
126+
const float beta,
127+
const dpctl::tensor::usm_ndarray &result,
128+
const std::vector<sycl::event> &depends)
129+
{
130+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
131+
132+
int nd = result.get_ndim();
133+
if (nd != 1) {
134+
throw py::value_error("Array should be 1d");
135+
}
136+
137+
if (!dpctl::utils::queues_are_compatible(exec_q, {result.get_queue()})) {
138+
throw py::value_error(
139+
"Execution queue is not compatible with allocation queue.");
140+
}
141+
142+
const bool is_result_c_contig = result.is_c_contiguous();
143+
if (!is_result_c_contig) {
144+
throw py::value_error("The result input array is not c-contiguous.");
145+
}
146+
147+
size_t nelems = result.get_size();
148+
if (nelems == 0) {
149+
return std::make_pair(sycl::event{}, sycl::event{});
150+
}
151+
152+
int result_typenum = result.get_typenum();
153+
auto array_types = dpctl_td_ns::usm_ndarray_types();
154+
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
155+
auto fn = kaiser_dispatch_vector[result_type_id];
156+
157+
if (fn == nullptr) {
158+
throw std::runtime_error("Type of given array is not supported");
159+
}
160+
161+
char *result_typeless_ptr = result.get_data();
162+
sycl::event kaiser_ev =
163+
fn(exec_q, result_typeless_ptr, nelems, beta, depends);
164+
sycl::event args_ev =
165+
dpctl::utils::keep_args_alive(exec_q, {result}, {kaiser_ev});
166+
167+
return std::make_pair(args_ev, kaiser_ev);
168+
}
169+
170+
void init_kaiser_dispatch_vectors()
171+
{
172+
dpctl_td_ns::DispatchVectorBuilder<kaiser_fn_ptr_t, KaiserFactory,
173+
dpctl_td_ns::num_types>
174+
contig;
175+
contig.populate_dispatch_vector(kaiser_dispatch_vector);
176+
177+
return;
178+
}
179+
180+
} // namespace dpnp::extensions::window
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <dpctl4pybind11.hpp>
29+
#include <sycl/sycl.hpp>
30+
31+
namespace dpnp::extensions::window
32+
{
33+
extern std::pair<sycl::event, sycl::event>
34+
py_kaiser(sycl::queue &exec_q,
35+
const float beta,
36+
const dpctl::tensor::usm_ndarray &result,
37+
const std::vector<sycl::event> &depends);
38+
39+
extern void init_kaiser_dispatch_vectors(void);
40+
41+
} // namespace dpnp::extensions::window

dpnp/backend/extensions/window/window_py.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "common.hpp"
3636
#include "hamming.hpp"
3737
#include "hanning.hpp"
38+
#include "kaiser.hpp"
3839

3940
namespace window_ns = dpnp::extensions::window;
4041
namespace py = pybind11;
@@ -111,4 +112,12 @@ PYBIND11_MODULE(_window_impl, m)
111112
py::arg("sycl_queue"), py::arg("result"),
112113
py::arg("depends") = py::list());
113114
}
115+
116+
{
117+
window_ns::init_kaiser_dispatch_vectors();
118+
119+
m.def("_kaiser", window_ns::py_kaiser, "Call Kaiser kernel",
120+
py::arg("sycl_queue"), py::arg("beta"), py::arg("result"),
121+
py::arg("depends") = py::list());
122+
}
114123
}

0 commit comments

Comments
 (0)