Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added implementation of `dpnp.hanning` [#2358](https://github.com/IntelPython/dpnp/pull/2358)
* Added implementation of `dpnp.blackman` [#2363](https://github.com/IntelPython/dpnp/pull/2363)
* Added implementation of `dpnp.bartlett` [#2366](https://github.com/IntelPython/dpnp/pull/2366)
* Added implementation of `dpnp.kaiser` [#2387](https://github.com/IntelPython/dpnp/pull/2387)

### Changed

Expand Down
1 change: 1 addition & 0 deletions dpnp/backend/extensions/window/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

set(python_module_name _window_impl)
set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/kaiser.cpp
${CMAKE_CURRENT_SOURCE_DIR}/window_py.cpp
)

Expand Down
36 changes: 27 additions & 9 deletions dpnp/backend/extensions/window/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ sycl::event window_impl(sycl::queue &q,
return window_ev;
}

std::pair<sycl::event, sycl::event>
py_window(sycl::queue &exec_q,
template <typename funcPtrT>
std::tuple<size_t, char *, funcPtrT>
window_fn(sycl::queue &exec_q,
const dpctl::tensor::usm_ndarray &result,
const std::vector<sycl::event> &depends,
const window_fn_ptr_t *window_dispatch_vector)
const funcPtrT *window_dispatch_vector)
{
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);

Expand All @@ -92,30 +92,48 @@ std::pair<sycl::event, sycl::event>

size_t nelems = result.get_size();
if (nelems == 0) {
return std::make_pair(sycl::event{}, sycl::event{});
return std::make_tuple(nelems, nullptr, nullptr);
}

int result_typenum = result.get_typenum();
auto array_types = dpctl_td_ns::usm_ndarray_types();
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
auto fn = window_dispatch_vector[result_type_id];
funcPtrT fn = window_dispatch_vector[result_type_id];

if (fn == nullptr) {
throw std::runtime_error("Type of given array is not supported");
}

char *result_typeless_ptr = result.get_data();
return std::make_tuple(nelems, result_typeless_ptr, fn);
}

inline std::pair<sycl::event, sycl::event>
py_window(sycl::queue &exec_q,
const dpctl::tensor::usm_ndarray &result,
const std::vector<sycl::event> &depends,
const window_fn_ptr_t *window_dispatch_vector)
{
auto [nelems, result_typeless_ptr, fn] =
window_fn<window_fn_ptr_t>(exec_q, result, window_dispatch_vector);

if (nelems == 0) {
return std::make_pair(sycl::event{}, sycl::event{});
}

sycl::event window_ev = fn(exec_q, result_typeless_ptr, nelems, depends);
sycl::event args_ev =
dpctl::utils::keep_args_alive(exec_q, {result}, {window_ev});

return std::make_pair(args_ev, window_ev);
}

template <template <typename fnT, typename T> typename factoryT>
void init_window_dispatch_vectors(window_fn_ptr_t window_dispatch_vector[])
template <typename funcPtrT,
template <typename fnT, typename T>
typename factoryT>
void init_window_dispatch_vectors(funcPtrT window_dispatch_vector[])
{
dpctl_td_ns::DispatchVectorBuilder<window_fn_ptr_t, factoryT,
dpctl_td_ns::DispatchVectorBuilder<funcPtrT, factoryT,
dpctl_td_ns::num_types>
contig;
contig.populate_dispatch_vector(window_dispatch_vector);
Expand Down
155 changes: 155 additions & 0 deletions dpnp/backend/extensions/window/kaiser.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
//*****************************************************************************
// Copyright (c) 2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include "kaiser.hpp"
#include "common.hpp"

#include "utils/output_validation.hpp"
#include "utils/type_dispatch.hpp"
#include "utils/type_utils.hpp"

#include <sycl/sycl.hpp>

/**
* Version of SYCL DPC++ 2025.1 compiler where an issue with
* sycl::ext::intel::math::cyl_bessel_i0(x) is fully resolved.
*/
#ifndef __SYCL_COMPILER_BESSEL_I0_SUPPORT
#define __SYCL_COMPILER_BESSEL_I0_SUPPORT 20241208L
#endif

#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT
#include <sycl/ext/intel/math.hpp>
#endif

#include "../kernels/elementwise_functions/i0.hpp"

namespace dpnp::extensions::window
{
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

typedef sycl::event (*kaiser_fn_ptr_t)(sycl::queue &,
char *,
const std::size_t,
const py::object &,
const std::vector<sycl::event> &);

static kaiser_fn_ptr_t kaiser_dispatch_vector[dpctl_td_ns::num_types];

template <typename T>
class KaiserFunctor
{
private:
T *data = nullptr;
const std::size_t N;
const T beta;

public:
KaiserFunctor(T *data, const std::size_t N, const T beta)
: data(data), N(N), beta(beta)
{
}

void operator()(sycl::id<1> id) const
{
#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT
using sycl::ext::intel::math::cyl_bessel_i0;
#else
using dpnp::kernels::i0::impl::cyl_bessel_i0;
#endif

const auto i = id.get(0);
const T alpha = (N - 1) / T(2);
const T tmp = (i - alpha) / alpha;
data[i] = cyl_bessel_i0(beta * sycl::sqrt(1 - tmp * tmp)) /
cyl_bessel_i0(beta);
}
};

template <typename T, template <typename> class Functor>
sycl::event kaiser_impl(sycl::queue &q,
char *result,
const std::size_t nelems,
const py::object &py_beta,
const std::vector<sycl::event> &depends)
{
dpctl::tensor::type_utils::validate_type_for_device<T>(q);

T *res = reinterpret_cast<T *>(result);
const T beta = py::cast<const T>(py_beta);

sycl::event kaiser_ev = q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using KaiserKernel = Functor<T>;
cgh.parallel_for<KaiserKernel>(sycl::range<1>(nelems),
KaiserKernel(res, nelems, beta));
});

return kaiser_ev;
}

template <typename fnT, typename T>
struct KaiserFactory
{
fnT get()
{
if constexpr (std::is_floating_point_v<T>) {
return kaiser_impl<T, KaiserFunctor>;
}
else {
return nullptr;
}
}
};

std::pair<sycl::event, sycl::event>
py_kaiser(sycl::queue &exec_q,
const py::object &py_beta,
const dpctl::tensor::usm_ndarray &result,
const std::vector<sycl::event> &depends)
{
auto [nelems, result_typeless_ptr, fn] =
window_fn<kaiser_fn_ptr_t>(exec_q, result, kaiser_dispatch_vector);

if (nelems == 0) {
return std::make_pair(sycl::event{}, sycl::event{});
}

sycl::event kaiser_ev =
fn(exec_q, result_typeless_ptr, nelems, py_beta, depends);
sycl::event args_ev =
dpctl::utils::keep_args_alive(exec_q, {result}, {kaiser_ev});

return std::make_pair(args_ev, kaiser_ev);
}

void init_kaiser_dispatch_vectors()
{
init_window_dispatch_vectors<kaiser_fn_ptr_t, KaiserFactory>(
kaiser_dispatch_vector);
}

} // namespace dpnp::extensions::window
41 changes: 41 additions & 0 deletions dpnp/backend/extensions/window/kaiser.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//*****************************************************************************
// Copyright (c) 2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <dpctl4pybind11.hpp>
#include <sycl/sycl.hpp>

namespace dpnp::extensions::window
{
extern std::pair<sycl::event, sycl::event>
py_kaiser(sycl::queue &exec_q,
const py::object &beta,
const dpctl::tensor::usm_ndarray &result,
const std::vector<sycl::event> &depends);

extern void init_kaiser_dispatch_vectors(void);

} // namespace dpnp::extensions::window
21 changes: 17 additions & 4 deletions dpnp/backend/extensions/window/window_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "common.hpp"
#include "hamming.hpp"
#include "hanning.hpp"
#include "kaiser.hpp"

namespace window_ns = dpnp::extensions::window;
namespace py = pybind11;
Expand All @@ -54,7 +55,8 @@ PYBIND11_MODULE(_window_impl, m)

{
window_ns::init_window_dispatch_vectors<
window_ns::kernels::BartlettFactory>(bartlett_dispatch_vector);
window_ns::window_fn_ptr_t, window_ns::kernels::BartlettFactory>(
bartlett_dispatch_vector);

auto bartlett_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
const event_vecT &depends = {}) {
Expand All @@ -69,7 +71,8 @@ PYBIND11_MODULE(_window_impl, m)

{
window_ns::init_window_dispatch_vectors<
window_ns::kernels::BlackmanFactory>(blackman_dispatch_vector);
window_ns::window_fn_ptr_t, window_ns::kernels::BlackmanFactory>(
blackman_dispatch_vector);

auto blackman_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
const event_vecT &depends = {}) {
Expand All @@ -84,7 +87,8 @@ PYBIND11_MODULE(_window_impl, m)

{
window_ns::init_window_dispatch_vectors<
window_ns::kernels::HammingFactory>(hamming_dispatch_vector);
window_ns::window_fn_ptr_t, window_ns::kernels::HammingFactory>(
hamming_dispatch_vector);

auto hamming_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
const event_vecT &depends = {}) {
Expand All @@ -99,7 +103,8 @@ PYBIND11_MODULE(_window_impl, m)

{
window_ns::init_window_dispatch_vectors<
window_ns::kernels::HanningFactory>(hanning_dispatch_vector);
window_ns::window_fn_ptr_t, window_ns::kernels::HanningFactory>(
hanning_dispatch_vector);

auto hanning_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
const event_vecT &depends = {}) {
Expand All @@ -111,4 +116,12 @@ PYBIND11_MODULE(_window_impl, m)
py::arg("sycl_queue"), py::arg("result"),
py::arg("depends") = py::list());
}

{
window_ns::init_kaiser_dispatch_vectors();

m.def("_kaiser", window_ns::py_kaiser, "Call Kaiser kernel",
py::arg("sycl_queue"), py::arg("beta"), py::arg("result"),
py::arg("depends") = py::list());
}
}
Loading
Loading