|
| 1 | +/*************************************************************************** |
| 2 | + * |
| 3 | + * Copyright (C) Codeplay Software Ltd. |
| 4 | + * |
| 5 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | + * you may not use this file except in compliance with the License. |
| 7 | + * You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + * |
| 17 | + * Codeplay's portFFT |
| 18 | + * |
| 19 | + **************************************************************************/ |
| 20 | + |
| 21 | +#ifndef PORTFFT_DESCRIPTOR_VALIDATE_HPP |
| 22 | +#define PORTFFT_DESCRIPTOR_VALIDATE_HPP |
| 23 | + |
| 24 | +#include <string_view> |
| 25 | + |
| 26 | +#include "common/exceptions.hpp" |
| 27 | +#include "common/workitem.hpp" |
| 28 | +#include "enums.hpp" |
| 29 | +#include "utils.hpp" |
| 30 | + |
| 31 | +namespace portfft::detail::validate { |
| 32 | + |
| 33 | +/** |
| 34 | + * Throw an exception if the lengths are invalid when looked at in isolation. |
| 35 | + * |
| 36 | + * @param lengths the dimensions of the transform |
| 37 | + */ |
| 38 | +inline void validate_lengths(const std::vector<std::size_t>& lengths) { |
| 39 | + if (lengths.empty()) { |
| 40 | + throw invalid_configuration("Invalid lengths, must have at least 1 dimension"); |
| 41 | + } |
| 42 | + for (std::size_t i = 0; i < lengths.size(); ++i) { |
| 43 | + if (lengths[i] == 0) { |
| 44 | + throw invalid_configuration("Invalid lengths[", i, "]=", lengths[i], ", must be positive"); |
| 45 | + } |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +/** |
| 50 | + * Throw an exception if the layout is unsupported. |
| 51 | + * |
| 52 | + * @tparam Scalar the scalar type for the transform |
| 53 | + * @param lengths the dimensions of the transform |
| 54 | + * @param forward_layout the layout of the forward domain |
| 55 | + * @param backward_layout the layout of the backward domain |
| 56 | + */ |
| 57 | +template <typename Scalar> |
| 58 | +inline void validate_layout(const std::vector<std::size_t>& lengths, portfft::detail::layout forward_layout, |
| 59 | + portfft::detail::layout backward_layout) { |
| 60 | + if (lengths.size() > 1) { |
| 61 | + const bool supported_layout = |
| 62 | + forward_layout == portfft::detail::layout::PACKED && backward_layout == portfft::detail::layout::PACKED; |
| 63 | + if (!supported_layout) { |
| 64 | + throw unsupported_configuration("Multi-dimensional transforms are only supported with default data layout"); |
| 65 | + } |
| 66 | + } |
| 67 | + if (forward_layout == portfft::detail::layout::UNPACKED || backward_layout == portfft::detail::layout::UNPACKED) { |
| 68 | + if (!portfft::detail::fits_in_wi<Scalar>(lengths.back())) { |
| 69 | + throw unsupported_configuration( |
| 70 | + "Arbitrary strides and distances are only supported for sizes that fit in the registers of a single " |
| 71 | + "work-item"); |
| 72 | + } |
| 73 | + } |
| 74 | +} |
| 75 | + |
| 76 | +/** |
| 77 | + * Throw an exception if individual stride, distance and number_of_transforms values are invalid/inconsistent. |
| 78 | + * |
| 79 | + * @param lengths the dimensions of the transform |
| 80 | + * @param number_of_transforms the number of batches |
| 81 | + * @param strides the strides between elements in a domain |
| 82 | + * @param distance the distance between batches in a domain |
| 83 | + * @param domain_str a string with the name of the domain being validated |
| 84 | + */ |
| 85 | +inline void validate_strides_distance_basic(const std::vector<std::size_t>& lengths, std::size_t number_of_transforms, |
| 86 | + const std::vector<std::size_t>& strides, std::size_t distance, |
| 87 | + const std::string_view domain_str) { |
| 88 | + // Validate stride |
| 89 | + std::size_t expected_num_strides = lengths.size(); |
| 90 | + if (strides.size() != expected_num_strides) { |
| 91 | + throw invalid_configuration("Mismatching ", domain_str, " strides length got ", strides.size(), " expected ", |
| 92 | + expected_num_strides); |
| 93 | + } |
| 94 | + for (std::size_t i = 0; i < strides.size(); ++i) { |
| 95 | + if (strides[i] == 0) { |
| 96 | + throw invalid_configuration("Invalid ", domain_str, " stride[", i, "]=", strides[i], ", must be positive"); |
| 97 | + } |
| 98 | + } |
| 99 | + |
| 100 | + // Validate distance |
| 101 | + if (number_of_transforms > 1 && distance == 0) { |
| 102 | + throw invalid_configuration("Invalid ", domain_str, " distance ", distance, ", must be positive for batched FFTs"); |
| 103 | + } |
| 104 | +} |
| 105 | + |
| 106 | +/** |
| 107 | + * For multidimensional transforms, check that the strides are large enough so there will not be overlap within a single |
| 108 | + * batch. Throw when the strides are not big enough. This accounts for layouts like batch interleaved. |
| 109 | + * |
| 110 | + * @param lengths the dimensions of the transform |
| 111 | + * @param number_of_transforms the number of batches |
| 112 | + * @param strides the strides between elements in a domain |
| 113 | + * @param distance the distance between batches in a domain |
| 114 | + * @param domain_str a string with the name of the domain being validated |
| 115 | + */ |
| 116 | +inline void strides_distance_multidim_check(const std::vector<std::size_t>& lengths, std::size_t number_of_transforms, |
| 117 | + const std::vector<std::size_t>& strides, std::size_t distance, |
| 118 | + const std::string_view domain_str) { |
| 119 | + // Quick check for most common configurations. |
| 120 | + // This check has some false-negative for some impractical configurations. |
| 121 | + // View the output data as a N+1 dimensional tensor for a N-dimension FFT: the number of batch is just another |
| 122 | + // dimension with a stride of 'distance'. This sorts the dimensions from fastest moving (inner-most) to slowest |
| 123 | + // moving (outer-most) and check that the stride of a dimension is large enough to avoid overlapping the previous |
| 124 | + // dimension. |
| 125 | + std::vector<std::size_t> generic_strides = strides; |
| 126 | + std::vector<std::size_t> generic_sizes = lengths; |
| 127 | + if (number_of_transforms > 1) { |
| 128 | + generic_strides.push_back(distance); |
| 129 | + generic_sizes.push_back(number_of_transforms); |
| 130 | + } |
| 131 | + std::vector<std::size_t> indices(generic_sizes.size()); |
| 132 | + std::iota(indices.begin(), indices.end(), 0); |
| 133 | + std::sort(indices.begin(), indices.end(), |
| 134 | + [&](std::size_t a, std::size_t b) { return generic_strides[a] < generic_strides[b]; }); |
| 135 | + |
| 136 | + for (std::size_t i = 1; i < indices.size(); ++i) { |
| 137 | + bool fits_in_next_dim = |
| 138 | + generic_strides[indices[i - 1]] * generic_sizes[indices[i - 1]] <= generic_strides[indices[i]]; |
| 139 | + if (!fits_in_next_dim) { |
| 140 | + throw invalid_configuration("Domain ", domain_str, |
| 141 | + ": multi-dimension strides are not large enough to avoid overlap"); |
| 142 | + } |
| 143 | + } |
| 144 | +} |
| 145 | + |
| 146 | +/** |
| 147 | + * Check that batches of 1D FFTs don't overlap. |
| 148 | + * |
| 149 | + * @param lengths the dimensions of the transform |
| 150 | + * @param number_of_transforms the number of batches |
| 151 | + * @param strides the strides between elements in a domain |
| 152 | + * @param distance the distance between batches in a domain |
| 153 | + * @param domain_str a string with the name of the domain being validated |
| 154 | + */ |
| 155 | +inline void strides_distance_1d_check(const std::vector<std::size_t>& lengths, std::size_t number_of_transforms, |
| 156 | + const std::vector<std::size_t>& strides, std::size_t distance, |
| 157 | + const std::string_view domain_str) { |
| 158 | + // It helps to think of the 1D transform layed out in 2D with row length of stride, that way each element of a |
| 159 | + // transform will be contiguous down a column. |
| 160 | + |
| 161 | + // * If there is an index collision between batch N and batch N+M, then there will also be a collision between batch |
| 162 | + // N-1 and batch N+M-1, so if there is any index collision, there will also be one with batch 0 (batch N-N and batch |
| 163 | + // N+M-N). |
| 164 | + // * If an index in a transform mod the stride of the transform is zero, then it would collide with the first batch, |
| 165 | + // given an infinite FFT length. For all elements in a transforms, the index mod stride is the same. |
| 166 | + // * If an element in a batch index collides with another batch, then all previous elements in that batch will also |
| 167 | + // index collide with that batch, so we only need to check the first index of each batch. |
| 168 | + |
| 169 | + const std::size_t fft_size = lengths[0]; |
| 170 | + const std::size_t stride = strides[0]; |
| 171 | + |
| 172 | + const std::size_t first_batch_limit = stride * fft_size; |
| 173 | + const std::size_t first_length_limit = distance * number_of_transforms; |
| 174 | + if ((stride <= distance && first_batch_limit <= distance) || (distance <= stride && first_length_limit <= stride)) { |
| 175 | + return; |
| 176 | + } |
| 177 | + |
| 178 | + for (std::size_t b = 1; b < number_of_transforms;) { |
| 179 | + std::size_t batch_first_idx = b * distance; |
| 180 | + auto column = batch_first_idx % stride; |
| 181 | + if (column == 0) { // there may be a collision with the first batch |
| 182 | + if (batch_first_idx >= first_batch_limit) { |
| 183 | + // any further batch will only be further way |
| 184 | + return; |
| 185 | + } |
| 186 | + throw invalid_configuration("Domain ", domain_str, ": batch ", b, " collides with first batch at index ", |
| 187 | + batch_first_idx); |
| 188 | + } |
| 189 | + |
| 190 | + // there won't be another collision until the column number is near to stride again, so skip a few |
| 191 | + auto batches_until_new_column = (stride - column) / distance; |
| 192 | + if ((stride - column) % distance != 0) { |
| 193 | + batches_until_new_column += 1; |
| 194 | + } |
| 195 | + b += batches_until_new_column; |
| 196 | + } |
| 197 | +} |
| 198 | + |
| 199 | +/** |
| 200 | + * Throw an exception if the given strides and distance are invalid for a single domain. |
| 201 | + * |
| 202 | + * @param lengths the dimensions of the transform |
| 203 | + * @param number_of_transforms the number of batches |
| 204 | + * @param strides the strides between elements in a domain |
| 205 | + * @param distance the distance between batches in a domain |
| 206 | + * @param domain_str a string with the name of the domain being validated |
| 207 | + */ |
| 208 | +inline void strides_distance_check(const std::vector<std::size_t>& lengths, std::size_t number_of_transforms, |
| 209 | + const std::vector<std::size_t>& strides, std::size_t distance, |
| 210 | + const std::string_view domain_str) { |
| 211 | + validate_strides_distance_basic(lengths, number_of_transforms, strides, distance, domain_str); |
| 212 | + if (lengths.size() > 1) { |
| 213 | + strides_distance_multidim_check(lengths, number_of_transforms, strides, distance, domain_str); |
| 214 | + } else { |
| 215 | + strides_distance_1d_check(lengths, number_of_transforms, strides, distance, domain_str); |
| 216 | + } |
| 217 | +} |
| 218 | + |
| 219 | +/** |
| 220 | + * Throw an exception if the given strides and distances are invalid for either domain. |
| 221 | + * |
| 222 | + * @param place where the result is written with respect to where it is read (in-place vs not in-place) |
| 223 | + * @param lengths the dimensions of the transform |
| 224 | + * @param number_of_transforms the number of batches |
| 225 | + * @param forward_strides the strides between elements in the forward domain |
| 226 | + * @param backward_strides the strides between elements in the backward domain |
| 227 | + * @param forward_distance the distance between batches in the forward domain |
| 228 | + * @param backward_distance the distance between batches in the backward domain |
| 229 | + */ |
| 230 | +inline void validate_strides_distance(placement place, const std::vector<std::size_t>& lengths, |
| 231 | + std::size_t number_of_transforms, const std::vector<std::size_t>& forward_strides, |
| 232 | + const std::vector<std::size_t>& backward_strides, std::size_t forward_distance, |
| 233 | + std::size_t backward_distance) { |
| 234 | + if (place == placement::IN_PLACE) { |
| 235 | + if (forward_strides != backward_strides) { |
| 236 | + throw invalid_configuration("Invalid forward and backward strides must match for in-place configurations"); |
| 237 | + } |
| 238 | + if (forward_distance != backward_distance) { |
| 239 | + throw invalid_configuration("Invalid forward and backward distances must match for in-place configurations"); |
| 240 | + } |
| 241 | + strides_distance_check(lengths, number_of_transforms, forward_strides, forward_distance, "forward"); |
| 242 | + } else { |
| 243 | + strides_distance_check(lengths, number_of_transforms, forward_strides, forward_distance, "forward"); |
| 244 | + strides_distance_check(lengths, number_of_transforms, backward_strides, backward_distance, "backward"); |
| 245 | + } |
| 246 | +} |
| 247 | + |
| 248 | +/** |
| 249 | + * @brief Check as much as possible if a given descriptor is valid and supported for the current capabilties of portFFT. |
| 250 | + * @details The descriptor can still later be deemed unsupported if it is not immediately obvious. If the descriptor is |
| 251 | + * invalid, it should be reported here or not at all. |
| 252 | + * |
| 253 | + * @param params the final description of the problem. |
| 254 | + * @throws portfft::unsupported_configuration when the configuration is unsupported |
| 255 | + * @throws portfft::invalid_configuration when the configuration is invalid e.g. would cause elements to overlap |
| 256 | + */ |
| 257 | +template <typename Descriptor> |
| 258 | +void validate_descriptor(const Descriptor& params) { |
| 259 | + using namespace portfft; |
| 260 | + |
| 261 | + if constexpr (Descriptor::Domain == domain::REAL) { |
| 262 | + throw unsupported_configuration("REAL domain is unsupported"); |
| 263 | + } |
| 264 | + |
| 265 | + if (params.number_of_transforms == 0) { |
| 266 | + throw invalid_configuration("Invalid number of transform ", params.number_of_transforms, ", must be positive"); |
| 267 | + } |
| 268 | + |
| 269 | + validate_lengths(params.lengths); |
| 270 | + validate_strides_distance(params.placement, params.lengths, params.number_of_transforms, params.forward_strides, |
| 271 | + params.backward_strides, params.forward_distance, params.backward_distance); |
| 272 | + validate_layout<typename Descriptor::Scalar>(params.lengths, portfft::detail::get_layout(params, direction::FORWARD), |
| 273 | + portfft::detail::get_layout(params, direction::BACKWARD)); |
| 274 | +} |
| 275 | + |
| 276 | +} // namespace portfft::detail::validate |
| 277 | + |
| 278 | +#endif |
0 commit comments