@@ -216,46 +216,44 @@ class committed_descriptor_impl {
216216 throw unsupported_configuration (" portFFT only supports complex to complex transforms" );
217217 }
218218
219- std::vector<sycl::kernel_id> ids;
220- std::vector<Idx> factors;
221219 IdxGlobal fft_size = static_cast <IdxGlobal>(params.lengths [kernel_num]);
222- if (detail::fits_in_wi<Scalar>(fft_size)) {
223- ids = detail::get_ids<detail::workitem_kernel, Scalar, Domain, SubgroupSize>();
224- PORTFFT_LOG_TRACE (" Prepared workitem impl for size: " , fft_size);
225- return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, factors}}};
226- }
227- if (detail::fits_in_sg<Scalar>(fft_size, SubgroupSize)) {
228- Idx factor_sg = detail::factorize_sg (static_cast <Idx>(fft_size), SubgroupSize);
229- Idx factor_wi = static_cast <Idx>(fft_size) / factor_sg;
230- // This factorization is duplicated in the dispatch logic on the device.
231- // The CT and spec constant factors should match.
232- factors.push_back (factor_wi);
233- factors.push_back (factor_sg);
234- ids = detail::get_ids<detail::subgroup_kernel, Scalar, Domain, SubgroupSize>();
235- PORTFFT_LOG_TRACE (" Prepared subgroup impl with factor_wi:" , factor_wi, " and factor_sg:" , factor_sg);
236- return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}};
237- }
238- if (auto wg_factorization = detail::factorize_for_wg<Scalar>(fft_size, SubgroupSize); wg_factorization) {
239- auto [factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m] = wg_factorization.value ();
240- Idx temp_num_sgs_in_wg;
241- std::size_t local_memory_usage =
242- num_scalars_in_local_mem (detail::level::WORKGROUP, static_cast <std::size_t >(fft_size), SubgroupSize,
243- {factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg,
244- layout::PACKED) *
245- sizeof (Scalar);
246- // Checks for PACKED layout only at the moment, as the other layout will not be supported
247- // by the global implementation. For such sizes, only PACKED layout will be supported
248- if (local_memory_usage <= static_cast <std::size_t >(local_memory_size)) {
249- factors.push_back (factor_wi_n);
250- factors.push_back (factor_sg_n);
251- factors.push_back (factor_wi_m);
252- factors.push_back (factor_sg_m);
253- // This factorization of N and M is duplicated in the dispatch logic on the device.
220+ if (static_cast <size_t >(fft_size) * 2 * sizeof (Scalar) <= static_cast <size_t >(local_memory_size)) {
221+ // These implementations only work if the size fits in local memory.
222+ // They still may not be suitable if the extra local memory needed for the algorithm exceeds the available memory.
223+
224+ if (detail::fits_in_wi<Scalar>(fft_size)) {
225+ auto ids = detail::get_ids<detail::workitem_kernel, Scalar, Domain, SubgroupSize>();
226+ PORTFFT_LOG_TRACE (" Prepared workitem impl for size: " , fft_size);
227+ return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, {}}}};
228+ }
229+ if (detail::fits_in_sg<Scalar>(fft_size, SubgroupSize)) {
230+ Idx factor_sg = detail::factorize_sg (static_cast <Idx>(fft_size), SubgroupSize);
231+ Idx factor_wi = static_cast <Idx>(fft_size) / factor_sg;
232+ // This factorization is duplicated in the dispatch logic on the device.
254233 // The CT and spec constant factors should match.
255- ids = detail::get_ids<detail::workgroup_kernel, Scalar, Domain, SubgroupSize>();
256- PORTFFT_LOG_TRACE (" Prepared workgroup impl with factor_wi_n:" , factor_wi_n, " factor_sg_n:" , factor_sg_n,
257- " factor_wi_m:" , factor_wi_m, " factor_sg_m:" , factor_sg_m);
258- return {detail::level::WORKGROUP, {{detail::level::WORKGROUP, ids, factors}}};
234+ auto ids = detail::get_ids<detail::subgroup_kernel, Scalar, Domain, SubgroupSize>();
235+ PORTFFT_LOG_TRACE (" Prepared subgroup impl with factor_wi:" , factor_wi, " and factor_sg:" , factor_sg);
236+ return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, {factor_wi, factor_sg}}}};
237+ }
238+ if (auto wg_factorization = detail::factorize_for_wg<Scalar>(fft_size, SubgroupSize); wg_factorization) {
239+ auto [factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m] = wg_factorization.value ();
240+ Idx temp_num_sgs_in_wg;
241+ std::size_t local_memory_usage =
242+ num_scalars_in_local_mem (detail::level::WORKGROUP, static_cast <std::size_t >(fft_size), SubgroupSize,
243+ {factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg,
244+ layout::PACKED) *
245+ sizeof (Scalar);
246+ // Checks for PACKED layout only at the moment, as the other layout will not be supported
247+ // by the global implementation. For such sizes, only PACKED layout will be supported
248+ if (local_memory_usage <= static_cast <std::size_t >(local_memory_size)) {
249+ // This factorization of N and M is duplicated in the dispatch logic on the device.
250+ // The CT and spec constant factors should match.
251+ auto ids = detail::get_ids<detail::workgroup_kernel, Scalar, Domain, SubgroupSize>();
252+ PORTFFT_LOG_TRACE (" Prepared workgroup impl with factor_wi_n:" , factor_wi_n, " factor_sg_n:" , factor_sg_n,
253+ " factor_wi_m:" , factor_wi_m, " factor_sg_m:" , factor_sg_m);
254+ return {detail::level::WORKGROUP,
255+ {{detail::level::WORKGROUP, ids, {factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m}}}};
256+ }
259257 }
260258 }
261259 PORTFFT_LOG_TRACE (" Preparing global impl" );
0 commit comments