Skip to content

Commit 037fa6f

Browse files
committed
only try to prepare workitem/subgroup/workgroup impls for sizes that fit in local memory
1 parent bbfbbb8 commit 037fa6f

File tree

1 file changed

+36
-38
lines changed

1 file changed

+36
-38
lines changed

src/portfft/committed_descriptor_impl.hpp

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -216,46 +216,44 @@ class committed_descriptor_impl {
216216
throw unsupported_configuration("portFFT only supports complex to complex transforms");
217217
}
218218

219-
std::vector<sycl::kernel_id> ids;
220-
std::vector<Idx> factors;
221219
IdxGlobal fft_size = static_cast<IdxGlobal>(params.lengths[kernel_num]);
222-
if (detail::fits_in_wi<Scalar>(fft_size)) {
223-
ids = detail::get_ids<detail::workitem_kernel, Scalar, Domain, SubgroupSize>();
224-
PORTFFT_LOG_TRACE("Prepared workitem impl for size: ", fft_size);
225-
return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, factors}}};
226-
}
227-
if (detail::fits_in_sg<Scalar>(fft_size, SubgroupSize)) {
228-
Idx factor_sg = detail::factorize_sg(static_cast<Idx>(fft_size), SubgroupSize);
229-
Idx factor_wi = static_cast<Idx>(fft_size) / factor_sg;
230-
// This factorization is duplicated in the dispatch logic on the device.
231-
// The CT and spec constant factors should match.
232-
factors.push_back(factor_wi);
233-
factors.push_back(factor_sg);
234-
ids = detail::get_ids<detail::subgroup_kernel, Scalar, Domain, SubgroupSize>();
235-
PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, "and factor_sg:", factor_sg);
236-
return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}};
237-
}
238-
if (auto wg_factorization = detail::factorize_for_wg<Scalar>(fft_size, SubgroupSize); wg_factorization) {
239-
auto [factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m] = wg_factorization.value();
240-
Idx temp_num_sgs_in_wg;
241-
std::size_t local_memory_usage =
242-
num_scalars_in_local_mem(detail::level::WORKGROUP, static_cast<std::size_t>(fft_size), SubgroupSize,
243-
{factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg,
244-
layout::PACKED) *
245-
sizeof(Scalar);
246-
// Checks for PACKED layout only at the moment, as the other layout will not be supported
247-
// by the global implementation. For such sizes, only PACKED layout will be supported
248-
if (local_memory_usage <= static_cast<std::size_t>(local_memory_size)) {
249-
factors.push_back(factor_wi_n);
250-
factors.push_back(factor_sg_n);
251-
factors.push_back(factor_wi_m);
252-
factors.push_back(factor_sg_m);
253-
// This factorization of N and M is duplicated in the dispatch logic on the device.
220+
if (static_cast<size_t>(fft_size) * 2 * sizeof(Scalar) <= static_cast<size_t>(local_memory_size)) {
221+
// These implementations only work if the size fits in local memory.
222+
// They may still be unsuitable if the extra local memory the algorithm needs exceeds the available memory.
223+
224+
if (detail::fits_in_wi<Scalar>(fft_size)) {
225+
auto ids = detail::get_ids<detail::workitem_kernel, Scalar, Domain, SubgroupSize>();
226+
PORTFFT_LOG_TRACE("Prepared workitem impl for size: ", fft_size);
227+
return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, {}}}};
228+
}
229+
if (detail::fits_in_sg<Scalar>(fft_size, SubgroupSize)) {
230+
Idx factor_sg = detail::factorize_sg(static_cast<Idx>(fft_size), SubgroupSize);
231+
Idx factor_wi = static_cast<Idx>(fft_size) / factor_sg;
232+
// This factorization is duplicated in the dispatch logic on the device.
254233
// The CT and spec constant factors should match.
255-
ids = detail::get_ids<detail::workgroup_kernel, Scalar, Domain, SubgroupSize>();
256-
PORTFFT_LOG_TRACE("Prepared workgroup impl with factor_wi_n:", factor_wi_n, " factor_sg_n:", factor_sg_n,
257-
" factor_wi_m:", factor_wi_m, " factor_sg_m:", factor_sg_m);
258-
return {detail::level::WORKGROUP, {{detail::level::WORKGROUP, ids, factors}}};
234+
auto ids = detail::get_ids<detail::subgroup_kernel, Scalar, Domain, SubgroupSize>();
235+
PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, "and factor_sg:", factor_sg);
236+
return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, {factor_wi, factor_sg}}}};
237+
}
238+
if (auto wg_factorization = detail::factorize_for_wg<Scalar>(fft_size, SubgroupSize); wg_factorization) {
239+
auto [factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m] = wg_factorization.value();
240+
Idx temp_num_sgs_in_wg;
241+
std::size_t local_memory_usage =
242+
num_scalars_in_local_mem(detail::level::WORKGROUP, static_cast<std::size_t>(fft_size), SubgroupSize,
243+
{factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg,
244+
layout::PACKED) *
245+
sizeof(Scalar);
246+
// Checks for PACKED layout only at the moment, as the other layout will not be supported
247+
// by the global implementation. For such sizes, only PACKED layout will be supported
248+
if (local_memory_usage <= static_cast<std::size_t>(local_memory_size)) {
249+
// This factorization of N and M is duplicated in the dispatch logic on the device.
250+
// The CT and spec constant factors should match.
251+
auto ids = detail::get_ids<detail::workgroup_kernel, Scalar, Domain, SubgroupSize>();
252+
PORTFFT_LOG_TRACE("Prepared workgroup impl with factor_wi_n:", factor_wi_n, " factor_sg_n:", factor_sg_n,
253+
" factor_wi_m:", factor_wi_m, " factor_sg_m:", factor_sg_m);
254+
return {detail::level::WORKGROUP,
255+
{{detail::level::WORKGROUP, ids, {factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m}}}};
256+
}
259257
}
260258
}
261259
PORTFFT_LOG_TRACE("Preparing global impl");

0 commit comments

Comments
 (0)