Skip to content

Commit 324a88f

Browse files
authored
work-item strided transforms (#136)
* work-item strided transforms * move validation to before the committed_descriptor constructor * format * test tidying * remove layout from workitem-dispatcher * update comments * clarify distance for 1d kernel launch * format * add check that strided ffts fit in workitem * clarified README * clarify the use of stride and distance in dispatch_kernel_1d * added shortcut validation for batch_interleaved * rename descriptor_validate to descriptor_validation
1 parent d5fe215 commit 324a88f

16 files changed

+935
-409
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ portFFT is still in early development. The supported configurations are:
9696
* size in each dimension must be supported by 1D transforms
9797
* Arbitrary forward and backward scales
9898
* Arbitrary forward and backward offsets
99+
* Arbitrary strides and distance where the problem size + auxiliary data fits in the registers of a single work-item.
99100

100101
Any 1D arbitrarily large input size that fits in global memory is supported, with a restriction that large input sizes should not have large prime factors.
101102
The largest prime factor depends on the device and the values set by `PORTFFT_REGISTERS_PER_WI` and `PORTFFT_SUBGROUP_SIZES`.
@@ -106,6 +107,8 @@ Any batch size is supported as long as the input and output data fits in global
106107

107108
By default the library assumes subgroup size of 32 is used. If that is not supported by the device it is running on, the subgroup size can be set using `PORTFFT_SUBGROUP_SIZES`.
108109

110+
Configurations that attempt to read from the same memory address from two separate batches of a transform are not supported.
111+
109112
## Known issues
110113

111114
* portFFT relies on SYCL specialization constants which have some limitations currently:

src/portfft/committed_descriptor_impl.hpp

Lines changed: 128 additions & 220 deletions
Large diffs are not rendered by default.

src/portfft/common/global.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,10 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc
151151
IdxGlobal outer_batch_offset = get_outer_batch_offset(factors, inner_batches, inclusive_scan, num_factors,
152152
level_num, iter_value, outer_batch_product, storage);
153153
if (level == detail::level::WORKITEM) {
154-
workitem_impl<SubgroupSize, LayoutIn, LayoutOut, Scalar>(
155-
input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset,
156-
output_imag + outer_batch_offset, input_loc, batch_size, global_data, kh, static_cast<const Scalar*>(nullptr),
157-
store_modifier_data, static_cast<Scalar*>(nullptr), store_modifier_loc);
154+
workitem_impl<SubgroupSize, Scalar>(input + outer_batch_offset, output + outer_batch_offset,
155+
input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc,
156+
batch_size, global_data, kh, static_cast<const Scalar*>(nullptr),
157+
store_modifier_data, static_cast<Scalar*>(nullptr), store_modifier_loc);
158158
} else if (level == detail::level::SUBGROUP) {
159159
subgroup_impl<SubgroupSize, LayoutIn, LayoutOut, Scalar>(
160160
input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset,

src/portfft/common/logging.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ struct logging_config {
5252
}
5353
#endif
5454
}
55-
char* log_trace_str = getenv("PORTFFT_LOG_TRACE");
55+
char* log_trace_str = getenv("PORTFFT_LOG_TRACES");
5656
if (log_trace_str != nullptr) {
5757
log_trace = static_cast<bool>(atoi(log_trace_str));
58-
#ifndef PORTFFT_LOG_TRACE
58+
#ifndef PORTFFT_LOG_TRACES
5959
if (log_trace) {
6060
std::cerr << "Can not enable logging of traces if it is disabled at compile time." << std::endl;
6161
}
@@ -281,7 +281,7 @@ struct global_data_struct {
281281
*/
282282
template <typename... Ts>
283283
PORTFFT_INLINE void log_message_global([[maybe_unused]] Ts... messages) {
284-
#ifdef PORTFFT_LOG_TRACE
284+
#ifdef PORTFFT_LOG_TRACES
285285
if (global_logging_config.log_trace && it.get_global_id(0) == 0) {
286286
log_message_impl(messages...);
287287
}

src/portfft/descriptor.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
#include <numeric>
2828
#include <vector>
2929

30+
#include "committed_descriptor.hpp"
3031
#include "defines.hpp"
32+
#include "descriptor_validation.hpp"
3133
#include "enums.hpp"
3234

33-
#include "committed_descriptor.hpp"
34-
3535
namespace portfft {
3636

3737
/**
@@ -151,6 +151,7 @@ struct descriptor {
151151
*/
152152
committed_descriptor<Scalar, Domain> commit(sycl::queue& queue) {
153153
PORTFFT_LOG_FUNCTION_ENTRY();
154+
detail::validate::validate_descriptor(*this);
154155
return {*this, queue};
155156
}
156157

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
/***************************************************************************
2+
*
3+
* Copyright (C) Codeplay Software Ltd.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
* Codeplay's portFFT
18+
*
19+
**************************************************************************/
20+
21+
#ifndef PORTFFT_DESCRIPTOR_VALIDATE_HPP
22+
#define PORTFFT_DESCRIPTOR_VALIDATE_HPP
23+
24+
#include <string_view>
25+
26+
#include "common/exceptions.hpp"
27+
#include "common/workitem.hpp"
28+
#include "enums.hpp"
29+
#include "utils.hpp"
30+
31+
namespace portfft::detail::validate {
32+
33+
/**
34+
* Throw an exception if the lengths are invalid when looked at in isolation.
35+
*
36+
* @param lengths the dimensions of the transform
37+
*/
38+
inline void validate_lengths(const std::vector<std::size_t>& lengths) {
39+
if (lengths.empty()) {
40+
throw invalid_configuration("Invalid lengths, must have at least 1 dimension");
41+
}
42+
for (std::size_t i = 0; i < lengths.size(); ++i) {
43+
if (lengths[i] == 0) {
44+
throw invalid_configuration("Invalid lengths[", i, "]=", lengths[i], ", must be positive");
45+
}
46+
}
47+
}
48+
49+
/**
50+
* Throw an exception if the layout is unsupported.
51+
*
52+
* @tparam Scalar the scalar type for the transform
53+
* @param lengths the dimensions of the transform
54+
* @param forward_layout the layout of the forward domain
55+
* @param backward_layout the layout of the backward domain
56+
*/
57+
template <typename Scalar>
58+
inline void validate_layout(const std::vector<std::size_t>& lengths, portfft::detail::layout forward_layout,
59+
portfft::detail::layout backward_layout) {
60+
if (lengths.size() > 1) {
61+
const bool supported_layout =
62+
forward_layout == portfft::detail::layout::PACKED && backward_layout == portfft::detail::layout::PACKED;
63+
if (!supported_layout) {
64+
throw unsupported_configuration("Multi-dimensional transforms are only supported with default data layout");
65+
}
66+
}
67+
if (forward_layout == portfft::detail::layout::UNPACKED || backward_layout == portfft::detail::layout::UNPACKED) {
68+
if (!portfft::detail::fits_in_wi<Scalar>(lengths.back())) {
69+
throw unsupported_configuration(
70+
"Arbitrary strides and distances are only supported for sizes that fit in the registers of a single "
71+
"work-item");
72+
}
73+
}
74+
}
75+
76+
/**
77+
* Throw an exception if individual stride, distance and number_of_transforms values are invalid/inconsistent.
78+
*
79+
* @param lengths the dimensions of the transform
80+
* @param number_of_transforms the number of batches
81+
* @param strides the strides between elements in a domain
82+
* @param distance the distance between batches in a domain
83+
* @param domain_str a string with the name of the domain being validated
84+
*/
85+
inline void validate_strides_distance_basic(const std::vector<std::size_t>& lengths, std::size_t number_of_transforms,
86+
const std::vector<std::size_t>& strides, std::size_t distance,
87+
const std::string_view domain_str) {
88+
// Validate stride
89+
std::size_t expected_num_strides = lengths.size();
90+
if (strides.size() != expected_num_strides) {
91+
throw invalid_configuration("Mismatching ", domain_str, " strides length got ", strides.size(), " expected ",
92+
expected_num_strides);
93+
}
94+
for (std::size_t i = 0; i < strides.size(); ++i) {
95+
if (strides[i] == 0) {
96+
throw invalid_configuration("Invalid ", domain_str, " stride[", i, "]=", strides[i], ", must be positive");
97+
}
98+
}
99+
100+
// Validate distance
101+
if (number_of_transforms > 1 && distance == 0) {
102+
throw invalid_configuration("Invalid ", domain_str, " distance ", distance, ", must be positive for batched FFTs");
103+
}
104+
}
105+
106+
/**
107+
* For multidimensional transforms, check that the strides are large enough so there will not be overlap within a single
108+
* batch. Throw when the strides are not big enough. This accounts for layouts like batch interleaved.
109+
*
110+
* @param lengths the dimensions of the transform
111+
* @param number_of_transforms the number of batches
112+
* @param strides the strides between elements in a domain
113+
* @param distance the distance between batches in a domain
114+
* @param domain_str a string with the name of the domain being validated
115+
*/
116+
inline void strides_distance_multidim_check(const std::vector<std::size_t>& lengths, std::size_t number_of_transforms,
117+
const std::vector<std::size_t>& strides, std::size_t distance,
118+
const std::string_view domain_str) {
119+
// Quick check for most common configurations.
120+
// This check has some false negatives for some impractical configurations.
121+
// View the output data as a N+1 dimensional tensor for a N-dimension FFT: the number of batches is just another
122+
// dimension with a stride of 'distance'. This sorts the dimensions from fastest moving (inner-most) to slowest
123+
// moving (outer-most) and checks that the stride of a dimension is large enough to avoid overlapping the previous
124+
// dimension.
125+
std::vector<std::size_t> generic_strides = strides;
126+
std::vector<std::size_t> generic_sizes = lengths;
127+
if (number_of_transforms > 1) {
128+
generic_strides.push_back(distance);
129+
generic_sizes.push_back(number_of_transforms);
130+
}
131+
std::vector<std::size_t> indices(generic_sizes.size());
132+
std::iota(indices.begin(), indices.end(), 0);
133+
std::sort(indices.begin(), indices.end(),
134+
[&](std::size_t a, std::size_t b) { return generic_strides[a] < generic_strides[b]; });
135+
136+
for (std::size_t i = 1; i < indices.size(); ++i) {
137+
bool fits_in_next_dim =
138+
generic_strides[indices[i - 1]] * generic_sizes[indices[i - 1]] <= generic_strides[indices[i]];
139+
if (!fits_in_next_dim) {
140+
throw invalid_configuration("Domain ", domain_str,
141+
": multi-dimension strides are not large enough to avoid overlap");
142+
}
143+
}
144+
}
145+
146+
/**
147+
* Check that batches of 1D FFTs don't overlap.
148+
*
149+
* @param lengths the dimensions of the transform
150+
* @param number_of_transforms the number of batches
151+
* @param strides the strides between elements in a domain
152+
* @param distance the distance between batches in a domain
153+
* @param domain_str a string with the name of the domain being validated
154+
*/
155+
inline void strides_distance_1d_check(const std::vector<std::size_t>& lengths, std::size_t number_of_transforms,
156+
const std::vector<std::size_t>& strides, std::size_t distance,
157+
const std::string_view domain_str) {
158+
// It helps to think of the 1D transform laid out in 2D with row length of stride, that way each element of a
159+
// transform will be contiguous down a column.
160+
161+
// * If there is an index collision between batch N and batch N+M, then there will also be a collision between batch
162+
// N-1 and batch N+M-1, so if there is any index collision, there will also be one with batch 0 (batch N-N and batch
163+
// N+M-N).
164+
// * If an index in a transform mod the stride of the transform is zero, then it would collide with the first batch,
165+
// given an infinite FFT length. For all elements in a transform, the index mod stride is the same.
166+
// * If an element in a batch index collides with another batch, then all previous elements in that batch will also
167+
// index collide with that batch, so we only need to check the first index of each batch.
168+
169+
const std::size_t fft_size = lengths[0];
170+
const std::size_t stride = strides[0];
171+
172+
const std::size_t first_batch_limit = stride * fft_size;
173+
const std::size_t first_length_limit = distance * number_of_transforms;
174+
if ((stride <= distance && first_batch_limit <= distance) || (distance <= stride && first_length_limit <= stride)) {
175+
return;
176+
}
177+
178+
for (std::size_t b = 1; b < number_of_transforms;) {
179+
std::size_t batch_first_idx = b * distance;
180+
auto column = batch_first_idx % stride;
181+
if (column == 0) { // there may be a collision with the first batch
182+
if (batch_first_idx >= first_batch_limit) {
183+
// any further batch will only be further away
184+
return;
185+
}
186+
throw invalid_configuration("Domain ", domain_str, ": batch ", b, " collides with first batch at index ",
187+
batch_first_idx);
188+
}
189+
190+
// there won't be another collision until the column number is near to stride again, so skip a few
191+
auto batches_until_new_column = (stride - column) / distance;
192+
if ((stride - column) % distance != 0) {
193+
batches_until_new_column += 1;
194+
}
195+
b += batches_until_new_column;
196+
}
197+
}
198+
199+
/**
200+
* Throw an exception if the given strides and distance are invalid for a single domain.
201+
*
202+
* @param lengths the dimensions of the transform
203+
* @param number_of_transforms the number of batches
204+
* @param strides the strides between elements in a domain
205+
* @param distance the distance between batches in a domain
206+
* @param domain_str a string with the name of the domain being validated
207+
*/
208+
inline void strides_distance_check(const std::vector<std::size_t>& lengths, std::size_t number_of_transforms,
209+
const std::vector<std::size_t>& strides, std::size_t distance,
210+
const std::string_view domain_str) {
211+
validate_strides_distance_basic(lengths, number_of_transforms, strides, distance, domain_str);
212+
if (lengths.size() > 1) {
213+
strides_distance_multidim_check(lengths, number_of_transforms, strides, distance, domain_str);
214+
} else {
215+
strides_distance_1d_check(lengths, number_of_transforms, strides, distance, domain_str);
216+
}
217+
}
218+
219+
/**
220+
* Throw an exception if the given strides and distances are invalid for either domain.
221+
*
222+
* @param place where the result is written with respect to where it is read (in-place vs not in-place)
223+
* @param lengths the dimensions of the transform
224+
* @param number_of_transforms the number of batches
225+
* @param forward_strides the strides between elements in the forward domain
226+
* @param backward_strides the strides between elements in the backward domain
227+
* @param forward_distance the distance between batches in the forward domain
228+
* @param backward_distance the distance between batches in the backward domain
229+
*/
230+
inline void validate_strides_distance(placement place, const std::vector<std::size_t>& lengths,
231+
std::size_t number_of_transforms, const std::vector<std::size_t>& forward_strides,
232+
const std::vector<std::size_t>& backward_strides, std::size_t forward_distance,
233+
std::size_t backward_distance) {
234+
if (place == placement::IN_PLACE) {
235+
if (forward_strides != backward_strides) {
236+
throw invalid_configuration("Invalid forward and backward strides must match for in-place configurations");
237+
}
238+
if (forward_distance != backward_distance) {
239+
throw invalid_configuration("Invalid forward and backward distances must match for in-place configurations");
240+
}
241+
strides_distance_check(lengths, number_of_transforms, forward_strides, forward_distance, "forward");
242+
} else {
243+
strides_distance_check(lengths, number_of_transforms, forward_strides, forward_distance, "forward");
244+
strides_distance_check(lengths, number_of_transforms, backward_strides, backward_distance, "backward");
245+
}
246+
}
247+
248+
/**
249+
* @brief Check as much as possible if a given descriptor is valid and supported for the current capabilities of portFFT.
250+
* @details The descriptor can still later be deemed unsupported if it is not immediately obvious. If the descriptor is
251+
* invalid, it should be reported here or not at all.
252+
*
253+
* @param params the final description of the problem.
254+
* @throws portfft::unsupported_configuration when the configuration is unsupported
255+
* @throws portfft::invalid_configuration when the configuration is invalid e.g. would cause elements to overlap
256+
*/
257+
template <typename Descriptor>
258+
void validate_descriptor(const Descriptor& params) {
259+
using namespace portfft;
260+
261+
if constexpr (Descriptor::Domain == domain::REAL) {
262+
throw unsupported_configuration("REAL domain is unsupported");
263+
}
264+
265+
if (params.number_of_transforms == 0) {
266+
throw invalid_configuration("Invalid number of transform ", params.number_of_transforms, ", must be positive");
267+
}
268+
269+
validate_lengths(params.lengths);
270+
validate_strides_distance(params.placement, params.lengths, params.number_of_transforms, params.forward_strides,
271+
params.backward_strides, params.forward_distance, params.backward_distance);
272+
validate_layout<typename Descriptor::Scalar>(params.lengths, portfft::detail::get_layout(params, direction::FORWARD),
273+
portfft::detail::get_layout(params, direction::BACKWARD));
274+
}
275+
276+
} // namespace portfft::detail::validate
277+
278+
#endif

src/portfft/dispatcher/global_dispatcher.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,19 +256,18 @@ template <typename Scalar, domain Domain>
256256
template <typename Dummy>
257257
struct committed_descriptor_impl<Scalar, Domain>::set_spec_constants_struct::inner<detail::level::GLOBAL, Dummy> {
258258
static void execute(committed_descriptor_impl& /*desc*/, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
259-
std::size_t length, const std::vector<Idx>& factors, detail::level level, Idx factor_num,
259+
Idx length, const std::vector<Idx>& factors, detail::level level, Idx factor_num,
260260
Idx num_factors) {
261261
PORTFFT_LOG_FUNCTION_ENTRY();
262-
Idx length_idx = static_cast<Idx>(length);
263262
PORTFFT_LOG_TRACE("GlobalSubImplSpecConst:", level);
264263
in_bundle.template set_specialization_constant<detail::GlobalSubImplSpecConst>(level);
265264
PORTFFT_LOG_TRACE("GlobalSpecConstNumFactors:", num_factors);
266265
in_bundle.template set_specialization_constant<detail::GlobalSpecConstNumFactors>(num_factors);
267266
PORTFFT_LOG_TRACE("GlobalSpecConstLevelNum:", factor_num);
268267
in_bundle.template set_specialization_constant<detail::GlobalSpecConstLevelNum>(factor_num);
269268
if (level == detail::level::WORKITEM || level == detail::level::WORKGROUP) {
270-
PORTFFT_LOG_TRACE("SpecConstFftSize:", length_idx);
271-
in_bundle.template set_specialization_constant<detail::SpecConstFftSize>(length_idx);
269+
PORTFFT_LOG_TRACE("SpecConstFftSize:", length);
270+
in_bundle.template set_specialization_constant<detail::SpecConstFftSize>(length);
272271
} else if (level == detail::level::SUBGROUP) {
273272
PORTFFT_LOG_TRACE("SubgroupFactorWISpecConst:", factors[1]);
274273
in_bundle.template set_specialization_constant<detail::SubgroupFactorWISpecConst>(factors[1]);

src/portfft/dispatcher/subgroup_dispatcher.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,8 +676,8 @@ template <typename Scalar, domain Domain>
676676
template <typename Dummy>
677677
struct committed_descriptor_impl<Scalar, Domain>::set_spec_constants_struct::inner<detail::level::SUBGROUP, Dummy> {
678678
static void execute(committed_descriptor_impl& /*desc*/, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
679-
std::size_t /*length*/, const std::vector<Idx>& factors, detail::level /*level*/,
680-
Idx /*factor_num*/, Idx /*num_factors*/) {
679+
Idx /*length*/, const std::vector<Idx>& factors, detail::level /*level*/, Idx /*factor_num*/,
680+
Idx /*num_factors*/) {
681681
PORTFFT_LOG_FUNCTION_ENTRY();
682682
PORTFFT_LOG_TRACE("SubgroupFactorWISpecConst:", factors[0]);
683683
in_bundle.template set_specialization_constant<detail::SubgroupFactorWISpecConst>(factors[0]);

0 commit comments

Comments
 (0)