Skip to content

Commit 330f37b

Browse files
authored
Remove layout templates from src (#141)
1 parent 324a88f commit 330f37b

File tree

9 files changed

+230
-312
lines changed

9 files changed

+230
-312
lines changed

src/portfft/committed_descriptor_impl.hpp

Lines changed: 56 additions & 115 deletions
Large diffs are not rendered by default.

src/portfft/common/global.hpp

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,6 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors,
115115
* Device function responsible for calling the corresponding sub-implementation
116116
*
117117
* @tparam Scalar Scalar type
118-
* @tparam LayoutIn Input layout
119-
* @tparam LayoutOut Output layout
120118
* @tparam SubgroupSize Subgroup size
121119
* @param input input pointer
122120
* @param output output pointer
@@ -134,7 +132,7 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors,
134132
* @param global_data global data
135133
* @param kh kernel handler
136134
*/
137-
template <typename Scalar, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize>
135+
template <typename Scalar, Idx SubgroupSize>
138136
PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Scalar* input_imag, Scalar* output_imag,
139137
const Scalar* implementation_twiddles, const Scalar* store_modifier_data,
140138
Scalar* input_loc, Scalar* twiddles_loc, Scalar* store_modifier_loc,
@@ -156,16 +154,16 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc
156154
batch_size, global_data, kh, static_cast<const Scalar*>(nullptr),
157155
store_modifier_data, static_cast<Scalar*>(nullptr), store_modifier_loc);
158156
} else if (level == detail::level::SUBGROUP) {
159-
subgroup_impl<SubgroupSize, LayoutIn, LayoutOut, Scalar>(
160-
input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset,
161-
output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data,
162-
kh, static_cast<const Scalar*>(nullptr), store_modifier_data, static_cast<Scalar*>(nullptr),
163-
store_modifier_loc);
157+
subgroup_impl<SubgroupSize, Scalar>(input + outer_batch_offset, output + outer_batch_offset,
158+
input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc,
159+
twiddles_loc, batch_size, implementation_twiddles, global_data, kh,
160+
static_cast<const Scalar*>(nullptr), store_modifier_data,
161+
static_cast<Scalar*>(nullptr), store_modifier_loc);
164162
} else if (level == detail::level::WORKGROUP) {
165-
workgroup_impl<SubgroupSize, LayoutIn, LayoutOut, Scalar>(
166-
input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset,
167-
output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data,
168-
kh, static_cast<Scalar*>(nullptr), store_modifier_data);
163+
workgroup_impl<SubgroupSize, Scalar>(input + outer_batch_offset, output + outer_batch_offset,
164+
input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc,
165+
twiddles_loc, batch_size, implementation_twiddles, global_data, kh,
166+
static_cast<Scalar*>(nullptr), store_modifier_data);
169167
}
170168
sycl::group_barrier(global_data.it.get_group());
171169
}
@@ -277,8 +275,6 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Dom
277275
* Prepares the launch of fft compute at a particular level
278276
* @tparam Scalar Scalar type
279277
* @tparam Domain Domain of FFT
280-
* @tparam LayoutIn Input layout
281-
* @tparam LayoutOut output layout
282278
* @tparam SubgroupSize subgroup size
283279
* @tparam TIn input type
284280
* @param kd_struct associated kernel data struct with the factor
@@ -304,8 +300,7 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Dom
304300
* @param queue queue
305301
* @return vector of events, one for each batch in l2
306302
*/
307-
template <typename Scalar, domain Domain, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize,
308-
typename TIn>
303+
template <typename Scalar, domain Domain, Idx SubgroupSize, typename TIn>
309304
std::vector<sycl::event> compute_level(
310305
const typename committed_descriptor_impl<Scalar, Domain>::kernel_data_struct& kd_struct, const TIn& input,
311306
Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr,
@@ -380,7 +375,7 @@ std::vector<sycl::event> compute_level(
380375
#endif
381376
PORTFFT_LOG_TRACE("Launching kernel for global implementation with global_size", global_range, "local_size",
382377
local_range);
383-
cgh.parallel_for<global_kernel<Scalar, Domain, Mem, LayoutIn, LayoutOut, SubgroupSize>>(
378+
cgh.parallel_for<global_kernel<Scalar, Domain, Mem, SubgroupSize>>(
384379
sycl::nd_range<1>(sycl::range<1>(static_cast<std::size_t>(global_range)),
385380
sycl::range<1>(static_cast<std::size_t>(local_range))),
386381
[=
@@ -394,11 +389,11 @@ std::vector<sycl::event> compute_level(
394389
s, global_logging_config,
395390
#endif
396391
it};
397-
dispatch_level<Scalar, LayoutIn, LayoutOut, SubgroupSize>(
398-
&in_acc_or_usm[0] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0] + input_batch_offset,
399-
offset_output_imag, subimpl_twiddles, multipliers_between_factors, &loc_for_input[0],
400-
&loc_for_twiddles[0], &loc_for_modifier[0], factors_triple, inner_batches, inclusive_scan, batch_size,
401-
global_data, kh);
392+
dispatch_level<Scalar, SubgroupSize>(&in_acc_or_usm[0] + input_batch_offset, offset_output,
393+
&in_imag_acc_or_usm[0] + input_batch_offset, offset_output_imag,
394+
subimpl_twiddles, multipliers_between_factors, &loc_for_input[0],
395+
&loc_for_twiddles[0], &loc_for_modifier[0], factors_triple,
396+
inner_batches, inclusive_scan, batch_size, global_data, kh);
402397
});
403398
}));
404399
}

src/portfft/common/workgroup.hpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ namespace detail {
5656
/**
5757
* Calculate all dfts in one dimension of the data stored in local memory.
5858
*
59-
* @tparam LayoutIn Input Layout
6059
* @tparam SubgroupSize Size of the subgroup
6160
* @tparam LocalT The type of the local view
6261
* @tparam T Scalar type
@@ -73,7 +72,7 @@ namespace detail {
7372
* @param stride_within_dft Stride between elements of each DFT - also the number of the DFTs in the inner dimension
7473
* @param ndfts_in_outer_dimension Number of DFTs in outer dimension
7574
* @param storage complex storage: interleaved or split
76-
* @param layout_in Input Layout
75+
* @param input_layout the layout of the input data of the transforms
7776
* @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
7877
* @param multiply_on_store Whether the input data is multiplied with some data array after fft computation.
7978
* @param apply_scale_factor Whether or not the scale factor is applied
@@ -86,7 +85,7 @@ __attribute__((always_inline)) inline void dimension_dft(
8685
LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor, Idx max_num_batches_in_local_mem,
8786
Idx batch_num_in_local, const T* load_modifier_data, const T* store_modifier_data, IdxGlobal batch_num_in_kernel,
8887
Idx dft_size, Idx stride_within_dft, Idx ndfts_in_outer_dimension, complex_storage storage,
89-
detail::layout layout_in, detail::elementwise_multiply multiply_on_load,
88+
detail::layout input_layout, detail::elementwise_multiply multiply_on_load,
9089
detail::elementwise_multiply multiply_on_store, detail::apply_scale_factor apply_scale_factor,
9190
detail::complex_conjugate conjugate_on_load, detail::complex_conjugate conjugate_on_store,
9291
global_data_struct<1> global_data) {
@@ -149,7 +148,7 @@ __attribute__((always_inline)) inline void dimension_dft(
149148
working = working && static_cast<Idx>(global_data.sg.get_local_linear_id()) < max_working_tid_in_sg;
150149
}
151150
if (working) {
152-
if (layout_in == detail::layout::BATCH_INTERLEAVED) {
151+
if (input_layout == detail::layout::BATCH_INTERLEAVED) {
153152
global_data.log_message_global(__func__, "loading transposed data from local to private memory");
154153
if (storage == complex_storage::INTERLEAVED_COMPLEX) {
155154
detail::strided_view local_view{
@@ -249,7 +248,7 @@ __attribute__((always_inline)) inline void dimension_dft(
249248
}
250249
}
251250
global_data.log_dump_private("data in registers after computation:", priv, 2 * fact_wi);
252-
if (layout_in == detail::layout::BATCH_INTERLEAVED) {
251+
if (input_layout == detail::layout::BATCH_INTERLEAVED) {
253252
global_data.log_message_global(__func__, "storing transposed data from private to local memory");
254253
if (storage == complex_storage::INTERLEAVED_COMPLEX) {
255254
detail::strided_view local_view{
@@ -313,7 +312,7 @@ __attribute__((always_inline)) inline void dimension_dft(
313312
* @param N Smaller factor of the Problem size
314313
* @param M Larger factor of the problem size
315314
* @param storage complex storage: interleaved or split
316-
* @param layout_in Whether or not the input is transposed
315+
* @param input_layout the layout of the input data of the transforms
317316
* @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
318317
* @param multiply_on_store Whether the input data is multiplied with some data array after fft computation.
319318
* @param apply_scale_factor Whether or not the scale factor is applied
@@ -325,7 +324,7 @@ template <Idx SubgroupSize, typename LocalT, typename T>
325324
PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor,
326325
Idx max_num_batches_in_local_mem, Idx batch_num_in_local, IdxGlobal batch_num_in_kernel,
327326
const T* load_modifier_data, const T* store_modifier_data, Idx fft_size, Idx N, Idx M,
328-
complex_storage storage, detail::layout layout_in,
327+
complex_storage storage, detail::layout input_layout,
329328
detail::elementwise_multiply multiply_on_load,
330329
detail::elementwise_multiply multiply_on_store,
331330
detail::apply_scale_factor apply_scale_factor, detail::complex_conjugate conjugate_on_load,
@@ -336,14 +335,14 @@ PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T
336335
// column-wise DFTs
337336
detail::dimension_dft<SubgroupSize, LocalT, T>(
338337
loc, loc_twiddles + (2 * M), nullptr, 1, max_num_batches_in_local_mem, batch_num_in_local, load_modifier_data,
339-
store_modifier_data, batch_num_in_kernel, N, M, 1, storage, layout_in, multiply_on_load,
338+
store_modifier_data, batch_num_in_kernel, N, M, 1, storage, input_layout, multiply_on_load,
340339
detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, conjugate_on_load,
341340
detail::complex_conjugate::NOT_APPLIED, global_data);
342341
sycl::group_barrier(global_data.it.get_group());
343342
// row-wise DFTs, including twiddle multiplications and scaling
344343
detail::dimension_dft<SubgroupSize, LocalT, T>(
345344
loc, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, batch_num_in_local,
346-
load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, storage, layout_in,
345+
load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, storage, input_layout,
347346
detail::elementwise_multiply::NOT_APPLIED, multiply_on_store, apply_scale_factor,
348347
detail::complex_conjugate::NOT_APPLIED, conjugate_on_store, global_data);
349348
global_data.log_message_global(__func__, "exited");

0 commit comments

Comments
 (0)