Skip to content

Commit 012c478

Browse files
committed
Workgroup strided transforms
1 parent f859108 commit 012c478

File tree

6 files changed

+182
-110
lines changed

6 files changed

+182
-110
lines changed

src/portfft/committed_descriptor_impl.hpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include "common/exceptions.hpp"
3333
#include "common/subgroup.hpp"
34+
#include "common/workgroup.hpp"
3435
#include "defines.hpp"
3536
#include "enums.hpp"
3637
#include "specialization_constant.hpp"
@@ -234,18 +235,8 @@ class committed_descriptor_impl {
234235
PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, "and factor_sg:", factor_sg);
235236
return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}};
236237
}
237-
IdxGlobal n_idx_global = detail::factorize(fft_size);
238-
if (detail::can_cast_safely<IdxGlobal, Idx>(n_idx_global) &&
239-
detail::can_cast_safely<IdxGlobal, Idx>(fft_size / n_idx_global)) {
240-
if (n_idx_global == 1) {
241-
throw unsupported_configuration("FFT size ", fft_size, " : Large Prime sized FFT currently is unsupported");
242-
}
243-
Idx n = static_cast<Idx>(n_idx_global);
244-
Idx m = static_cast<Idx>(fft_size / n_idx_global);
245-
Idx factor_sg_n = detail::factorize_sg(n, SubgroupSize);
246-
Idx factor_wi_n = n / factor_sg_n;
247-
Idx factor_sg_m = detail::factorize_sg(m, SubgroupSize);
248-
Idx factor_wi_m = m / factor_sg_m;
238+
if (auto wg_factorization = detail::factorize_for_wg<Scalar>(fft_size, SubgroupSize); wg_factorization) {
239+
auto [factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m] = wg_factorization.value();
249240
Idx temp_num_sgs_in_wg;
250241
std::size_t local_memory_usage =
251242
num_scalars_in_local_mem(detail::level::WORKGROUP, static_cast<std::size_t>(fft_size), SubgroupSize,
@@ -254,8 +245,7 @@ class committed_descriptor_impl {
254245
sizeof(Scalar);
255246
// Checks for PACKED layout only at the moment, as the other layout will not be supported
256247
// by the global implementation. For such sizes, only PACKED layout will be supported
257-
if (detail::fits_in_wi<Scalar>(factor_wi_n) && detail::fits_in_wi<Scalar>(factor_wi_m) &&
258-
(local_memory_usage <= static_cast<std::size_t>(local_memory_size))) {
248+
if (local_memory_usage <= static_cast<std::size_t>(local_memory_size)) {
259249
factors.push_back(factor_wi_n);
260250
factors.push_back(factor_sg_n);
261251
factors.push_back(factor_wi_m);

src/portfft/common/workgroup.hpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,17 @@
2121
#ifndef PORTFFT_COMMON_WORKGROUP_HPP
2222
#define PORTFFT_COMMON_WORKGROUP_HPP
2323

24+
#include <optional>
25+
2426
#include "helpers.hpp"
2527
#include "logging.hpp"
28+
#include "memory_views.hpp"
2629
#include "portfft/defines.hpp"
2730
#include "portfft/enums.hpp"
2831
#include "portfft/traits.hpp"
32+
#include "portfft/utils.hpp"
2933
#include "subgroup.hpp"
34+
#include "transfers.hpp"
3035

3136
namespace portfft {
3237

@@ -53,6 +58,48 @@ constexpr T bank_lines_per_pad_wg(T row_size) {
5358
}
5459

5560
namespace detail {
61+
62+
// struct for the result of factorize_for_wg
63+
struct wg_factorization {
64+
Idx factor_wi_n;
65+
Idx factor_sg_n;
66+
Idx factor_wi_m;
67+
Idx factor_sg_m;
68+
};
69+
70+
/**
71+
*
72+
* Calculate a valid factorization for workgroup dfts, assuming there is sufficient local memory.
73+
*
74+
* @tparam Scalar scalar type of the transform data
75+
* @param fft_size the number of elements in the transforms
76+
* @param subgroup_size the size of subgroup used for the transform
77+
*
78+
* @return a factorization for workgroup dft or null if the size won't work with the implemenation of workgroup dfts.
79+
*/
80+
template <typename Scalar>
81+
inline std::optional<wg_factorization> factorize_for_wg(IdxGlobal fft_size, Idx subgroup_size) {
82+
IdxGlobal n_idx_global = detail::factorize(fft_size);
83+
if (n_idx_global == 1) {
84+
throw unsupported_configuration("FFT size ", fft_size, " : Large Prime sized FFT currently is unsupported");
85+
}
86+
IdxGlobal m_idx_global = fft_size / n_idx_global;
87+
if (detail::can_cast_safely<IdxGlobal, Idx>(n_idx_global) && detail::can_cast_safely<IdxGlobal, Idx>(m_idx_global)) {
88+
Idx n = static_cast<Idx>(n_idx_global);
89+
Idx m = static_cast<Idx>(m_idx_global);
90+
Idx factor_sg_n = detail::factorize_sg(n, subgroup_size);
91+
Idx factor_wi_n = n / factor_sg_n;
92+
Idx factor_sg_m = detail::factorize_sg(m, subgroup_size);
93+
Idx factor_wi_m = m / factor_sg_m;
94+
95+
if (fits_in_wi<Scalar>(factor_wi_n) && fits_in_wi<Scalar>(factor_wi_m)) {
96+
return wg_factorization{factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m};
97+
}
98+
}
99+
100+
return std::nullopt;
101+
}
102+
56103
/**
57104
* Calculate all dfts in one dimension of the data stored in local memory.
58105
*

src/portfft/descriptor_validation.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#include <string_view>
2525

2626
#include "common/exceptions.hpp"
27-
#include "common/subgroup.hpp"
27+
#include "common/workgroup.hpp"
2828
#include "enums.hpp"
2929
#include "utils.hpp"
3030

@@ -68,7 +68,7 @@ inline void validate_layout(const std::vector<std::size_t>& lengths, portfft::de
6868
bool fits_subgroup = false;
6969
for (auto sg_size : {PORTFFT_SUBGROUP_SIZES}) {
7070
fits_subgroup =
71-
fits_subgroup || portfft::detail::fits_in_sg<Scalar>(static_cast<IdxGlobal>(lengths.back()), sg_size);
71+
fits_subgroup || portfft::detail::factorize_for_wg<Scalar>(static_cast<IdxGlobal>(lengths.back()), sg_size);
7272
if (fits_subgroup) {
7373
break;
7474
}

0 commit comments

Comments
 (0)