@@ -115,8 +115,6 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors,
115115 * Device function responsible for calling the corresponding sub-implementation
116116 *
117117 * @tparam Scalar Scalar type
118- * @tparam LayoutIn Input layout
119- * @tparam LayoutOut Output layout
120118 * @tparam SubgroupSize Subgroup size
121119 * @param input input pointer
122120 * @param output output pointer
@@ -134,7 +132,7 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors,
134132 * @param global_data global data
135133 * @param kh kernel handler
136134 */
137- template <typename Scalar, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize>
135+ template <typename Scalar, Idx SubgroupSize>
138136PORTFFT_INLINE void dispatch_level (const Scalar* input, Scalar* output, const Scalar* input_imag, Scalar* output_imag,
139137 const Scalar* implementation_twiddles, const Scalar* store_modifier_data,
140138 Scalar* input_loc, Scalar* twiddles_loc, Scalar* store_modifier_loc,
@@ -156,16 +154,16 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc
156154 batch_size, global_data, kh, static_cast <const Scalar*>(nullptr ),
157155 store_modifier_data, static_cast <Scalar*>(nullptr ), store_modifier_loc);
158156 } else if (level == detail::level::SUBGROUP) {
159- subgroup_impl<SubgroupSize, LayoutIn, LayoutOut, Scalar>(
160- input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset,
161- output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data,
162- kh, static_cast <const Scalar*>(nullptr ), store_modifier_data, static_cast <Scalar*>( nullptr ) ,
163- store_modifier_loc);
157+ subgroup_impl<SubgroupSize, Scalar>(input + outer_batch_offset, output + outer_batch_offset,
158+ input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc ,
159+ twiddles_loc, batch_size, implementation_twiddles, global_data, kh ,
160+ static_cast <const Scalar*>(nullptr ), store_modifier_data,
161+ static_cast <Scalar*>( nullptr ), store_modifier_loc);
164162 } else if (level == detail::level::WORKGROUP) {
165- workgroup_impl<SubgroupSize, LayoutIn, LayoutOut, Scalar>(
166- input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset,
167- output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data,
168- kh, static_cast <Scalar*>(nullptr ), store_modifier_data);
163+ workgroup_impl<SubgroupSize, Scalar>(input + outer_batch_offset, output + outer_batch_offset,
164+ input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc ,
165+ twiddles_loc, batch_size, implementation_twiddles, global_data, kh ,
166+ static_cast <Scalar*>(nullptr ), store_modifier_data);
169167 }
170168 sycl::group_barrier (global_data.it .get_group ());
171169 }
@@ -277,8 +275,6 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Dom
277275 * Prepares the launch of fft compute at a particular level
278276 * @tparam Scalar Scalar type
279277 * @tparam Domain Domain of FFT
280- * @tparam LayoutIn Input layout
281- * @tparam LayoutOut output layout
282278 * @tparam SubgroupSize subgroup size
283279 * @tparam TIn input type
284280 * @param kd_struct kernel data struct associated with the factor
@@ -304,8 +300,7 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Dom
304300 * @param queue queue
305301 * @return vector of events, one for each batch in l2
306302 */
307- template <typename Scalar, domain Domain, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize,
308- typename TIn>
303+ template <typename Scalar, domain Domain, Idx SubgroupSize, typename TIn>
309304std::vector<sycl::event> compute_level (
310305 const typename committed_descriptor_impl<Scalar, Domain>::kernel_data_struct& kd_struct, const TIn& input,
311306 Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr,
@@ -380,7 +375,7 @@ std::vector<sycl::event> compute_level(
380375#endif
381376 PORTFFT_LOG_TRACE (" Launching kernel for global implementation with global_size" , global_range, " local_size" ,
382377 local_range);
383- cgh.parallel_for <global_kernel<Scalar, Domain, Mem, LayoutIn, LayoutOut, SubgroupSize>>(
378+ cgh.parallel_for <global_kernel<Scalar, Domain, Mem, SubgroupSize>>(
384379 sycl::nd_range<1 >(sycl::range<1 >(static_cast <std::size_t >(global_range)),
385380 sycl::range<1 >(static_cast <std::size_t >(local_range))),
386381 [=
@@ -394,11 +389,11 @@ std::vector<sycl::event> compute_level(
394389 s, global_logging_config,
395390#endif
396391 it};
397- dispatch_level<Scalar, LayoutIn, LayoutOut, SubgroupSize>(
398- &in_acc_or_usm[ 0 ] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0 ] + input_batch_offset,
399- offset_output_imag, subimpl_twiddles, multipliers_between_factors, &loc_for_input[0 ],
400- &loc_for_twiddles[0 ], &loc_for_modifier[0 ], factors_triple, inner_batches, inclusive_scan, batch_size ,
401- global_data, kh);
392+ dispatch_level<Scalar, SubgroupSize>(&in_acc_or_usm[ 0 ] + input_batch_offset, offset_output,
393+ &in_imag_acc_or_usm[0 ] + input_batch_offset, offset_output_imag ,
394+ subimpl_twiddles, multipliers_between_factors, &loc_for_input[0 ],
395+ &loc_for_twiddles[0 ], &loc_for_modifier[0 ], factors_triple,
396+ inner_batches, inclusive_scan, batch_size, global_data, kh);
402397 });
403398 }));
404399 }
0 commit comments