Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added implementation of `dpnp.hanning` [#2358](https://github.com/IntelPython/dpnp/pull/2358)
* Added implementation of `dpnp.blackman` [#2363](https://github.com/IntelPython/dpnp/pull/2363)
* Added implementation of `dpnp.bartlett` [#2366](https://github.com/IntelPython/dpnp/pull/2366)
* Added implementation of `dpnp.kaiser` [#2387](https://github.com/IntelPython/dpnp/pull/2387)

### Changed

Expand Down
1 change: 1 addition & 0 deletions dpnp/backend/extensions/window/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

set(python_module_name _window_impl)
set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/kaiser.cpp
${CMAKE_CURRENT_SOURCE_DIR}/window_py.cpp
)

Expand Down
36 changes: 27 additions & 9 deletions dpnp/backend/extensions/window/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ sycl::event window_impl(sycl::queue &q,
return window_ev;
}

std::pair<sycl::event, sycl::event>
py_window(sycl::queue &exec_q,
template <typename funcPtrT>
std::tuple<size_t, char *, funcPtrT>
window_fn(sycl::queue &exec_q,
const dpctl::tensor::usm_ndarray &result,
const std::vector<sycl::event> &depends,
const window_fn_ptr_t *window_dispatch_vector)
const funcPtrT *window_dispatch_vector)
{
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);

Expand All @@ -92,30 +92,48 @@ std::pair<sycl::event, sycl::event>

size_t nelems = result.get_size();
if (nelems == 0) {
return std::make_pair(sycl::event{}, sycl::event{});
return std::make_tuple(nelems, nullptr, nullptr);
}

int result_typenum = result.get_typenum();
auto array_types = dpctl_td_ns::usm_ndarray_types();
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
auto fn = window_dispatch_vector[result_type_id];
funcPtrT fn = window_dispatch_vector[result_type_id];

if (fn == nullptr) {
throw std::runtime_error("Type of given array is not supported");
}

char *result_typeless_ptr = result.get_data();
return std::make_tuple(nelems, result_typeless_ptr, fn);
}

inline std::pair<sycl::event, sycl::event>
py_window(sycl::queue &exec_q,
const dpctl::tensor::usm_ndarray &result,
const std::vector<sycl::event> &depends,
const window_fn_ptr_t *window_dispatch_vector)
{
auto [nelems, result_typeless_ptr, fn] =
window_fn<window_fn_ptr_t>(exec_q, result, window_dispatch_vector);

if (nelems == 0) {
return std::make_pair(sycl::event{}, sycl::event{});
}

sycl::event window_ev = fn(exec_q, result_typeless_ptr, nelems, depends);
sycl::event args_ev =
dpctl::utils::keep_args_alive(exec_q, {result}, {window_ev});

return std::make_pair(args_ev, window_ev);
}

template <template <typename fnT, typename T> typename factoryT>
void init_window_dispatch_vectors(window_fn_ptr_t window_dispatch_vector[])
template <typename funcPtrT,
template <typename fnT, typename T>
typename factoryT>
void init_window_dispatch_vectors(funcPtrT window_dispatch_vector[])
{
dpctl_td_ns::DispatchVectorBuilder<window_fn_ptr_t, factoryT,
dpctl_td_ns::DispatchVectorBuilder<funcPtrT, factoryT,
dpctl_td_ns::num_types>
contig;
contig.populate_dispatch_vector(window_dispatch_vector);
Expand Down
155 changes: 155 additions & 0 deletions dpnp/backend/extensions/window/kaiser.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
//*****************************************************************************
// Copyright (c) 2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include "kaiser.hpp"
#include "common.hpp"

#include "utils/output_validation.hpp"
#include "utils/type_dispatch.hpp"
#include "utils/type_utils.hpp"

#include <sycl/sycl.hpp>

/**
* Version of SYCL DPC++ 2025.1 compiler where an issue with
* sycl::ext::intel::math::cyl_bessel_i0(x) is fully resolved.
*/
#ifndef __SYCL_COMPILER_BESSEL_I0_SUPPORT
#define __SYCL_COMPILER_BESSEL_I0_SUPPORT 20241208L
#endif

#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT
#include <sycl/ext/intel/math.hpp>
#endif

#include "../kernels/elementwise_functions/i0.hpp"

namespace dpnp::extensions::window
{
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

typedef sycl::event (*kaiser_fn_ptr_t)(sycl::queue &,
char *,
const std::size_t,
const py::object &,
const std::vector<sycl::event> &);

static kaiser_fn_ptr_t kaiser_dispatch_vector[dpctl_td_ns::num_types];

template <typename T>
class KaiserFunctor
{
private:
T *data = nullptr;
const std::size_t N;
const T beta;

public:
KaiserFunctor(T *data, const std::size_t N, const T beta)
: data(data), N(N), beta(beta)
{
}

void operator()(sycl::id<1> id) const
{
#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT
using sycl::ext::intel::math::cyl_bessel_i0;
#else
using dpnp::kernels::i0::impl::cyl_bessel_i0;
#endif

const auto i = id.get(0);
const T alpha = (N - 1) / T(2);
const T tmp = (i - alpha) / alpha;
data[i] = cyl_bessel_i0(beta * sycl::sqrt(1 - tmp * tmp)) /
cyl_bessel_i0(beta);
}
};

template <typename T, template <typename> class Functor>
sycl::event kaiser_impl(sycl::queue &q,
char *result,
const std::size_t nelems,
const py::object &py_beta,
const std::vector<sycl::event> &depends)
{
dpctl::tensor::type_utils::validate_type_for_device<T>(q);

T *res = reinterpret_cast<T *>(result);
const T beta = py::cast<const T>(py_beta);

sycl::event kaiser_ev = q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using KaiserKernel = Functor<T>;
cgh.parallel_for<KaiserKernel>(sycl::range<1>(nelems),
KaiserKernel(res, nelems, beta));
});

return kaiser_ev;
}

template <typename fnT, typename T>
struct KaiserFactory
{
fnT get()
{
if constexpr (std::is_floating_point_v<T>) {
return kaiser_impl<T, KaiserFunctor>;
}
else {
return nullptr;
}
}
};

std::pair<sycl::event, sycl::event>
py_kaiser(sycl::queue &exec_q,
const py::object &py_beta,
const dpctl::tensor::usm_ndarray &result,
const std::vector<sycl::event> &depends)
{
auto [nelems, result_typeless_ptr, fn] =
window_fn<kaiser_fn_ptr_t>(exec_q, result, kaiser_dispatch_vector);

if (nelems == 0) {
return std::make_pair(sycl::event{}, sycl::event{});
}

sycl::event kaiser_ev =
fn(exec_q, result_typeless_ptr, nelems, py_beta, depends);
sycl::event args_ev =
dpctl::utils::keep_args_alive(exec_q, {result}, {kaiser_ev});

return std::make_pair(args_ev, kaiser_ev);
}

void init_kaiser_dispatch_vectors()
{
init_window_dispatch_vectors<kaiser_fn_ptr_t, KaiserFactory>(
kaiser_dispatch_vector);
}

} // namespace dpnp::extensions::window
41 changes: 41 additions & 0 deletions dpnp/backend/extensions/window/kaiser.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//*****************************************************************************
// Copyright (c) 2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <dpctl4pybind11.hpp>
#include <sycl/sycl.hpp>

namespace dpnp::extensions::window
{
extern std::pair<sycl::event, sycl::event>
py_kaiser(sycl::queue &exec_q,
const py::object &beta,
const dpctl::tensor::usm_ndarray &result,
const std::vector<sycl::event> &depends);

extern void init_kaiser_dispatch_vectors(void);

} // namespace dpnp::extensions::window
21 changes: 17 additions & 4 deletions dpnp/backend/extensions/window/window_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "common.hpp"
#include "hamming.hpp"
#include "hanning.hpp"
#include "kaiser.hpp"

namespace window_ns = dpnp::extensions::window;
namespace py = pybind11;
Expand All @@ -54,7 +55,8 @@ PYBIND11_MODULE(_window_impl, m)

{
window_ns::init_window_dispatch_vectors<
window_ns::kernels::BartlettFactory>(bartlett_dispatch_vector);
window_ns::window_fn_ptr_t, window_ns::kernels::BartlettFactory>(
bartlett_dispatch_vector);

auto bartlett_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
const event_vecT &depends = {}) {
Expand All @@ -69,7 +71,8 @@ PYBIND11_MODULE(_window_impl, m)

{
window_ns::init_window_dispatch_vectors<
window_ns::kernels::BlackmanFactory>(blackman_dispatch_vector);
window_ns::window_fn_ptr_t, window_ns::kernels::BlackmanFactory>(
blackman_dispatch_vector);

auto blackman_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
const event_vecT &depends = {}) {
Expand All @@ -84,7 +87,8 @@ PYBIND11_MODULE(_window_impl, m)

{
window_ns::init_window_dispatch_vectors<
window_ns::kernels::HammingFactory>(hamming_dispatch_vector);
window_ns::window_fn_ptr_t, window_ns::kernels::HammingFactory>(
hamming_dispatch_vector);

auto hamming_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
const event_vecT &depends = {}) {
Expand All @@ -99,7 +103,8 @@ PYBIND11_MODULE(_window_impl, m)

{
window_ns::init_window_dispatch_vectors<
window_ns::kernels::HanningFactory>(hanning_dispatch_vector);
window_ns::window_fn_ptr_t, window_ns::kernels::HanningFactory>(
hanning_dispatch_vector);

auto hanning_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
const event_vecT &depends = {}) {
Expand All @@ -111,4 +116,12 @@ PYBIND11_MODULE(_window_impl, m)
py::arg("sycl_queue"), py::arg("result"),
py::arg("depends") = py::list());
}

{
window_ns::init_kaiser_dispatch_vectors();

m.def("_kaiser", window_ns::py_kaiser, "Call Kaiser kernel",
py::arg("sycl_queue"), py::arg("beta"), py::arg("result"),
py::arg("depends") = py::list());
}
}
Loading
Loading