3131
3232#include " common/exceptions.hpp"
3333#include " common/subgroup.hpp"
34+ #include " common/workgroup.hpp"
3435#include " defines.hpp"
3536#include " enums.hpp"
3637#include " specialization_constant.hpp"
@@ -234,18 +235,8 @@ class committed_descriptor_impl {
234235 PORTFFT_LOG_TRACE (" Prepared subgroup impl with factor_wi:" , factor_wi, " and factor_sg:" , factor_sg);
235236 return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}};
236237 }
237- IdxGlobal n_idx_global = detail::factorize (fft_size);
238- if (detail::can_cast_safely<IdxGlobal, Idx>(n_idx_global) &&
239- detail::can_cast_safely<IdxGlobal, Idx>(fft_size / n_idx_global)) {
240- if (n_idx_global == 1 ) {
241- throw unsupported_configuration (" FFT size " , fft_size, " : Large Prime sized FFT currently is unsupported" );
242- }
243- Idx n = static_cast <Idx>(n_idx_global);
244- Idx m = static_cast <Idx>(fft_size / n_idx_global);
245- Idx factor_sg_n = detail::factorize_sg (n, SubgroupSize);
246- Idx factor_wi_n = n / factor_sg_n;
247- Idx factor_sg_m = detail::factorize_sg (m, SubgroupSize);
248- Idx factor_wi_m = m / factor_sg_m;
238+ if (auto wg_factorization = detail::factorize_for_wg<Scalar>(fft_size, SubgroupSize); wg_factorization) {
239+ auto [factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m] = wg_factorization.value ();
249240 Idx temp_num_sgs_in_wg;
250241 std::size_t local_memory_usage =
251242 num_scalars_in_local_mem (detail::level::WORKGROUP, static_cast <std::size_t >(fft_size), SubgroupSize,
@@ -254,8 +245,7 @@ class committed_descriptor_impl {
254245 sizeof (Scalar);
255246 // Checks for PACKED layout only at the moment, as the other layout will not be supported
256247 // by the global implementation. For such sizes, only PACKED layout will be supported
257- if (detail::fits_in_wi<Scalar>(factor_wi_n) && detail::fits_in_wi<Scalar>(factor_wi_m) &&
258- (local_memory_usage <= static_cast <std::size_t >(local_memory_size))) {
248+ if (local_memory_usage <= static_cast <std::size_t >(local_memory_size)) {
259249 factors.push_back (factor_wi_n);
260250 factors.push_back (factor_sg_n);
261251 factors.push_back (factor_wi_m);
0 commit comments