Skip to content

Commit 100094a

Browse files
vtavanaantonwolfy
andauthored
implement dpnp.fft.fft and dpnp.fft.ifft using pybind11 extension (#1879)
* implement fft and ifft using pybind11 extension * fixing minor issues * address comments * link to dpnp doc for FFT background info * add TODO * update backend structure * update CMakefile * update tests * add out keyword and fix an issue with negative stride * fix sphinx spelling issues * update for mkl-2024.2 * implement in-place fft * update to reuse dpctl function * fix an issue for out keyword given as usm_ndarray * fix a test * add TODO * extend descriptor template mkl_dft::domain is added to descriptor template for future implementation of REAL domain * add incorrectly removed header * address comments * fix empty array test * implement async memory * address comments * use dpnp_array * update a test * fix pre-commit * consistency with stock NumPy for coverage report --------- Co-authored-by: Anton <[email protected]>
1 parent 0dbc7e9 commit 100094a

22 files changed

+1915
-242
lines changed

doc/known_words.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
al
12
backend
23
bitwise
34
boolean
@@ -9,7 +10,9 @@ combinatorially
910
conda
1011
cubically
1112
Decompositions
13+
diag
1214
dimensionality
15+
discretized
1316
docstring
1417
dpctl
1518
dpnp
@@ -19,9 +22,11 @@ einsum
1922
endian
2023
eps
2124
epsneg
25+
et
2226
Extrema
2327
finfo
2428
finiteness
29+
Flannery
2530
Fortran
2631
Frobenius
2732
Hadamard
@@ -41,10 +46,12 @@ ndarray
4146
ndarrays
4247
ndim
4348
normed
49+
Nyquist
4450
oneAPI
4551
orthonormal
4652
Penrose
4753
Polyutils
54+
pre
4855
prepend
4956
prepending
5057
representable
@@ -58,14 +65,21 @@ subclasses
5865
subtype
5966
SyclDevice
6067
SyclQueue
68+
tensordot
69+
Teukolsky
6170
th
71+
tril
72+
triu
73+
Tukey
6274
ufunc
6375
ufuncs
6476
Unary
77+
unscaled
6578
unicode
6679
usm
6780
Vandermonde
6881
vectorized
82+
Vetterline
6983
von
7084
Weibull
7185
whitespace

doc/reference/fft.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,5 @@ Helper routines
6363
.. dpnp.fft.config.set_cufft_gpus
6464
.. dpnp.fft.config.get_plan_cache
6565
.. dpnp.fft.config.show_plan_cache_info
66+
67+
.. automodule:: dpnp.fft

dpnp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ endfunction()
5454

5555
add_subdirectory(backend)
5656
add_subdirectory(backend/extensions/blas)
57+
add_subdirectory(backend/extensions/fft)
5758
add_subdirectory(backend/extensions/lapack)
5859
add_subdirectory(backend/extensions/vm)
5960
add_subdirectory(backend/extensions/sycl_ext)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2024, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
#
13+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
# THE POSSIBILITY OF SUCH DAMAGE.
24+
# *****************************************************************************
25+
26+
27+
# Build script for the `_fft_impl` pybind11 extension module.

# Module name and the sources it is built from.
set(python_module_name _fft_impl)
set(_module_src
    ${CMAKE_CURRENT_SOURCE_DIR}/fft_py.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/in_place.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/out_of_place.cpp
)

pybind11_add_module(${python_module_name} MODULE ${_module_src})
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src})

if (WIN32)
    if (${CMAKE_VERSION} VERSION_LESS "3.27")
        # Work-around: target_link_options inserts the option after the -link
        # option, which causes the linker to ignore it, so the flag is added
        # to the global link flags instead.
        set(CMAKE_CXX_LINK_FLAGS "${CMAKE_CXX_LINK_FLAGS} -fsycl-device-code-split=per_kernel")
    endif()
endif()

# The target property is POSITION_INDEPENDENT_CODE; CMAKE_POSITION_INDEPENDENT_CODE
# is only the variable that seeds its default, so using that name as a target
# property has no effect.
set_target_properties(${python_module_name} PROPERTIES POSITION_INDEPENDENT_CODE ON)

# Headers shared with the rest of the backend.
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)

# dpctl headers.
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS})
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

# Same two floating-point options on every platform; on Windows they are
# passed through with the /clang: prefix.
if (WIN32)
    target_compile_options(${python_module_name} PRIVATE
        /clang:-fno-approx-func
        /clang:-fno-finite-math-only
    )
else()
    target_compile_options(${python_module_name} PRIVATE
        -fno-approx-func
        -fno-finite-math-only
    )
endif()

target_link_options(${python_module_name} PUBLIC -fsycl-device-code-split=per_kernel)

# Coverage instrumentation is only linked in when requested.
if (DPNP_GENERATE_COVERAGE)
    target_link_options(${python_module_name} PRIVATE -fprofile-instr-generate -fcoverage-mapping)
endif()

target_link_libraries(${python_module_name} PUBLIC MKL::MKL_SYCL::DFT)

install(TARGETS ${python_module_name}
    DESTINATION "dpnp/backend/extensions/fft"
)
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <oneapi/mkl.hpp>
29+
#include <pybind11/pybind11.h>
30+
#include <sycl/sycl.hpp>
31+
32+
namespace dpnp::extensions::fft
33+
{
34+
namespace mkl_dft = oneapi::mkl::dft;
35+
namespace py = pybind11;
36+
37+
template <mkl_dft::precision prec, mkl_dft::domain dom>
38+
class DescriptorWrapper
39+
{
40+
public:
41+
using descr_type = mkl_dft::descriptor<prec, dom>;
42+
43+
DescriptorWrapper(std::int64_t n) : descr_(n), queue_ptr_{} {}
44+
DescriptorWrapper(std::vector<std::int64_t> dimensions)
45+
: descr_(dimensions), queue_ptr_{}
46+
{
47+
}
48+
~DescriptorWrapper() {}
49+
50+
void commit(sycl::queue &q)
51+
{
52+
mkl_dft::precision fft_prec = get_precision();
53+
if (fft_prec == mkl_dft::precision::DOUBLE &&
54+
!q.get_device().has(sycl::aspect::fp64))
55+
{
56+
throw py::value_error("Descriptor is double precision but the "
57+
"device does not support double precision.");
58+
}
59+
60+
descr_.commit(q);
61+
queue_ptr_ = std::make_unique<sycl::queue>(q);
62+
}
63+
64+
descr_type &get_descriptor()
65+
{
66+
return descr_;
67+
}
68+
69+
const sycl::queue &get_queue() const
70+
{
71+
if (queue_ptr_) {
72+
return *queue_ptr_;
73+
}
74+
else {
75+
throw std::runtime_error(
76+
"Attempt to get queue when it is not yet set");
77+
}
78+
}
79+
80+
// config_param::DIMENSION
81+
template <typename valT = std::int64_t>
82+
const valT get_dim()
83+
{
84+
valT dim = -1;
85+
descr_.get_value(mkl_dft::config_param::DIMENSION, &dim);
86+
87+
return dim;
88+
}
89+
90+
// config_param::NUMBER_OF_TRANSFORMS
91+
template <typename valT = std::int64_t>
92+
const valT get_number_of_transforms()
93+
{
94+
valT transforms_count{};
95+
96+
descr_.get_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS,
97+
&transforms_count);
98+
return transforms_count;
99+
}
100+
101+
template <typename valT = std::int64_t>
102+
void set_number_of_transforms(const valT &num)
103+
{
104+
descr_.set_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS, num);
105+
}
106+
107+
// config_param::FWD_STRIDES
108+
template <typename valT = std::vector<std::int64_t>>
109+
const valT get_fwd_strides()
110+
{
111+
const typename valT::value_type dim = get_dim();
112+
113+
valT fwd_strides(dim + 1);
114+
descr_.get_value(mkl_dft::config_param::FWD_STRIDES,
115+
fwd_strides.data());
116+
return fwd_strides;
117+
}
118+
119+
template <typename valT = std::vector<std::int64_t>>
120+
void set_fwd_strides(const valT &strides)
121+
{
122+
const typename valT::value_type dim = get_dim();
123+
124+
if (static_cast<size_t>(dim + 1) != strides.size()) {
125+
throw py::value_error(
126+
"Strides length does not match descriptor's dimension");
127+
}
128+
descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides.data());
129+
}
130+
131+
// config_param::BWD_STRIDES
132+
template <typename valT = std::vector<std::int64_t>>
133+
const valT get_bwd_strides()
134+
{
135+
const typename valT::value_type dim = get_dim();
136+
137+
valT bwd_strides(dim + 1);
138+
descr_.get_value(mkl_dft::config_param::BWD_STRIDES,
139+
bwd_strides.data());
140+
return bwd_strides;
141+
}
142+
143+
template <typename valT = std::vector<std::int64_t>>
144+
void set_bwd_strides(const valT &strides)
145+
{
146+
const typename valT::value_type dim = get_dim();
147+
148+
if (static_cast<size_t>(dim + 1) != strides.size()) {
149+
throw py::value_error(
150+
"Strides length does not match descriptor's dimension");
151+
}
152+
descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides.data());
153+
}
154+
155+
// config_param::FWD_DISTANCE
156+
template <typename valT = std::int64_t>
157+
const valT get_fwd_distance()
158+
{
159+
valT dist = 0;
160+
161+
descr_.get_value(mkl_dft::config_param::FWD_DISTANCE, &dist);
162+
return dist;
163+
}
164+
165+
template <typename valT = std::int64_t>
166+
void set_fwd_distance(const valT &dist)
167+
{
168+
descr_.set_value(mkl_dft::config_param::FWD_DISTANCE, dist);
169+
}
170+
171+
// config_param::BWD_DISTANCE
172+
template <typename valT = std::int64_t>
173+
const valT get_bwd_distance()
174+
{
175+
valT dist = 0;
176+
177+
descr_.get_value(mkl_dft::config_param::BWD_DISTANCE, &dist);
178+
return dist;
179+
}
180+
181+
template <typename valT = std::int64_t>
182+
void set_bwd_distance(const valT &dist)
183+
{
184+
descr_.set_value(mkl_dft::config_param::BWD_DISTANCE, dist);
185+
}
186+
187+
// config_param::PLACEMENT
188+
bool get_in_place()
189+
{
190+
// TODO: replace when MKLD-10506 is implemented
191+
// mkl_dft::config_value placement;
192+
DFTI_CONFIG_VALUE placement;
193+
194+
descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
195+
// TODO: replace when MKLD-10506 is implemented
196+
// return (placement == mkl_dft::config_value::INPLACE);
197+
return (placement == DFTI_CONFIG_VALUE::DFTI_INPLACE);
198+
}
199+
200+
void set_in_place(const bool &in_place_request)
201+
{
202+
// TODO: replace when MKLD-10506 is implemented
203+
// descr_.set_value(mkl_dft::config_param::PLACEMENT, (in_place_request)
204+
// ? mkl_dft::config_value::INPLACE :
205+
// mkl_dft::config_value::NOT_INPLACE);
206+
descr_.set_value(mkl_dft::config_param::PLACEMENT,
207+
(in_place_request)
208+
? DFTI_CONFIG_VALUE::DFTI_INPLACE
209+
: DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE);
210+
}
211+
212+
// config_param::PRECISION
213+
mkl_dft::precision get_precision()
214+
{
215+
mkl_dft::precision fft_prec;
216+
217+
descr_.get_value(mkl_dft::config_param::PRECISION, &fft_prec);
218+
return fft_prec;
219+
}
220+
221+
// config_param::COMMIT_STATUS
222+
bool is_committed()
223+
{
224+
// TODO: replace when MKLD-10506 is implemented
225+
// mkl_dft::config_value committed;
226+
DFTI_CONFIG_VALUE committed;
227+
228+
descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
229+
// TODO: replace when MKLD-10506 is implemented
230+
// return (committed == mkl_dft::config_value::COMMITTED);
231+
return (committed == DFTI_CONFIG_VALUE::DFTI_COMMITTED);
232+
}
233+
234+
private:
235+
mkl_dft::descriptor<prec, dom> descr_;
236+
std::unique_ptr<sycl::queue> queue_ptr_;
237+
};
238+
239+
} // namespace dpnp::extensions::fft

0 commit comments

Comments
 (0)