@@ -40,9 +40,11 @@ namespace detail {
4040 * @param is_batch_interleaved is the input data layout batch interleaved
4141 * @param workgroup_size The size of the work-group. Must be divisible by 2.
4242 */
43- PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup (bool is_batch_interleaved,
44- Idx workgroup_size) noexcept {
45- return is_batch_interleaved ? workgroup_size / 2 : 1 ;
43+ PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup (bool /* is_batch_interleaved*/ ,
44+ Idx /* workgroup_size*/ ) noexcept {
45+ // TODO reenable when tests are passing
46+ // return is_batch_interleaved ? workgroup_size / 2 : 1;
47+ return 1 ;
4648}
4749
4850/* *
@@ -110,8 +112,9 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima
110112 const IdxGlobal input_distance = kh.get_specialization_constant <detail::SpecConstInputDistance>();
111113 const IdxGlobal output_distance = kh.get_specialization_constant <detail::SpecConstOutputDistance>();
112114
113- const bool is_input_batch_interleaved = input_stride == n_transforms && input_distance == 1 ;
114- const bool is_input_packed = input_stride == 1 && input_distance == fft_size;
115+ // TODO reable when tests are passing
116+ const bool is_input_batch_interleaved = false ; // input_stride == n_transforms && input_distance == 1;
117+ const bool is_input_packed = input_stride == 1 && input_distance == fft_size;
115118
116119 global_data.log_message_global (__func__, " entered" , " fft_size" , fft_size, " n_transforms" , n_transforms);
117120 Idx num_workgroups = static_cast <Idx>(global_data.it .get_group_range (0 ));
@@ -280,8 +283,8 @@ struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<SubgroupSize
280283 PORTFFT_LOG_FUNCTION_ENTRY ();
281284 auto & kernel_data = compute_direction == direction::FORWARD ? dimension_data.forward_kernels .at (0 )
282285 : dimension_data.backward_kernels .at (0 );
283- Idx num_batches_in_local_mem =
284- input_layout == layout::BATCH_INTERLEAVED ? kernel_data.used_sg_size * PORTFFT_SGS_IN_WG / 2 : 1 ;
286+ Idx num_batches_in_local_mem = detail::get_num_batches_in_local_mem_workgroup (
287+ input_layout == layout::BATCH_INTERLEAVED, kernel_data.used_sg_size * PORTFFT_SGS_IN_WG) ;
285288 constexpr detail::memory Mem = std::is_pointer_v<TOut> ? detail::memory::USM : detail::memory::BUFFER;
286289 Scalar* twiddles = kernel_data.twiddles_forward .get ();
287290 std::size_t local_elements =
@@ -355,8 +358,8 @@ struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struc
355358 // working memory + twiddles for subgroup impl for the two sizes
356359 Idx num_batches_in_local_mem = detail::get_num_batches_in_local_mem_workgroup (
357360 input_layout == layout::BATCH_INTERLEAVED, used_sg_size * PORTFFT_SGS_IN_WG);
358- return detail::pad_local ( static_cast <std::size_t >(2 * num_batches_in_local_mem) * length,
359- bank_lines_per_pad_wg ( 2 * static_cast <std::size_t >(sizeof (Scalar)) * m) ) +
361+ const auto bank_lines_per_pad = bank_lines_per_pad_wg ( 2 * static_cast <std::size_t >(sizeof (Scalar)) * m);
362+ return detail::pad_local ( static_cast <std::size_t >(2 * num_batches_in_local_mem) * length, bank_lines_per_pad ) +
360363 2 * (m + n);
361364 }
362365};
0 commit comments