@@ -139,11 +139,11 @@ static inline void fht2idt_core(
139139}
140140
141141template <typename Scalar>
142- void fht2idt_recursive (Tensor2DTyped<Scalar> const & src, Sign sign, int swaps[] ,
143- int swaps_buffer[], Scalar line_buffer[],
144- OutDegree out_degrees[], std::vector<int >& t_B_to_check,
145- std::vector<int >& t_T_to_check,
146- std::vector<bool >& t_processed) {
142+ void _fht2idt_recursive (Tensor2DTyped<Scalar> const & src, Sign sign,
143+ int swaps[], int swaps_buffer[], Scalar line_buffer[],
144+ OutDegree out_degrees[], std::vector<int >& t_B_to_check,
145+ std::vector<int >& t_T_to_check,
146+ std::vector<bool >& t_processed) {
147147 auto const height = src.height ;
148148 if A_UNLIKELY (height <= 1 ) {
149149 return ;
@@ -153,13 +153,13 @@ void fht2idt_recursive(Tensor2DTyped<Scalar> const& src, Sign sign, int swaps[],
153153 Tensor2D const I_B{slice_no_checks (src, h_T, src.height )};
154154
155155 if (I_T.height > 1 ) {
156- fht2idt_recursive (I_T.as <Scalar>(), sign, swaps, swaps_buffer, line_buffer,
157- out_degrees, t_B_to_check, t_T_to_check, t_processed);
156+ _fht2idt_recursive (I_T.as <Scalar>(), sign, swaps, swaps_buffer, line_buffer,
157+ out_degrees, t_B_to_check, t_T_to_check, t_processed);
158158 }
159159 if (I_B.height > 1 ) {
160- fht2idt_recursive (I_B.as <Scalar>(), sign, swaps + h_T, swaps_buffer + h_T,
161- line_buffer, out_degrees, t_B_to_check, t_T_to_check,
162- t_processed);
160+ _fht2idt_recursive (I_B.as <Scalar>(), sign, swaps + h_T, swaps_buffer + h_T,
161+ line_buffer, out_degrees, t_B_to_check, t_T_to_check,
162+ t_processed);
163163 }
164164 std::memcpy (swaps_buffer, swaps, height * sizeof (swaps_buffer[0 ]));
165165 fht2idt_core (height, sign, swaps, swaps_buffer + 0 , swaps_buffer + h_T,
@@ -170,82 +170,121 @@ void fht2idt_recursive(Tensor2DTyped<Scalar> const& src, Sign sign, int swaps[],
170170}
171171
172172template <typename Scalar>
173- void fht2idt_non_recursive (Tensor2DTyped<Scalar> const & src, Sign sign,
174- int swaps[], int swaps_buffer[],
175- Scalar line_buffer[], OutDegree out_degrees[],
176- std::vector<int >& t_B_to_check,
177- std::vector<int >& t_T_to_check,
178- std::vector<bool >& t_processed) {
173+ void _fht2idt_non_recursive (Tensor2DTyped<Scalar> const & src, Sign sign,
174+ int swaps[], int swaps_buffer[],
175+ Scalar line_buffer[], OutDegree out_degrees[],
176+ std::vector<int >& t_B_to_check,
177+ std::vector<int >& t_T_to_check,
178+ std::vector<bool >& t_processed,
179+ std::vector<ADRTTask> const & tasks) {
179180 auto const height = src.height ;
180181 if A_UNLIKELY (height <= 1 ) {
181182 return ;
182183 }
183-
184- non_recursive (
185- height,
186- [&](ADRTTask const & task) {
187- A_NEVER (task.size < 2 );
188- Tensor2D const I_T{slice_no_checks (src, task.start , task.mid )};
189- Tensor2D const I_B{slice_no_checks (src, task.mid , task.stop )};
190- int * cur_swaps_buffer = swaps_buffer + task.start ;
191- int * cur_swaps = swaps + task.start ;
192- std::memcpy (cur_swaps_buffer, cur_swaps,
193- task.size * sizeof (swaps_buffer[0 ]));
194- fht2idt_core (task.size , sign, cur_swaps, cur_swaps_buffer,
195- swaps_buffer + task.mid , line_buffer, I_T.as <Scalar>(),
196- I_B.as <Scalar>(), out_degrees, t_B_to_check, t_T_to_check,
197- t_processed);
198- },
199- [](int val) {
200- return static_cast <int >(div_by_pow2 (static_cast <uint32_t >(val)));
201- });
184+ for (ADRTTask const & task : tasks) {
185+ A_NEVER (task.size < 2 );
186+ Tensor2D const I_T{slice_no_checks (src, task.start , task.mid )};
187+ Tensor2D const I_B{slice_no_checks (src, task.mid , task.stop )};
188+ int * cur_swaps_buffer = swaps_buffer + task.start ;
189+ int * cur_swaps = swaps + task.start ;
190+ std::memcpy (cur_swaps_buffer, cur_swaps,
191+ task.size * sizeof (swaps_buffer[0 ]));
192+ fht2idt_core (task.size , sign, cur_swaps, cur_swaps_buffer,
193+ swaps_buffer + task.mid , line_buffer, I_T.as <Scalar>(),
194+ I_B.as <Scalar>(), out_degrees, t_B_to_check, t_T_to_check,
195+ t_processed);
196+ }
202197}
203198
204199template <typename Scalar>
205- class idt {
200+ struct idt_base {
206201 std::unique_ptr<int []> swaps_buffer;
207202 std::unique_ptr<Scalar[]> line_buffer;
208203 std::unique_ptr<OutDegree[]> out_degrees;
209204 std::vector<int > t_B_to_check;
210205 std::vector<int > t_T_to_check;
211206 std::vector<bool > t_processed;
207+ template <typename SwapsBuffer, typename LineBuffer, typename OutDegrees>
208+ idt_base (SwapsBuffer&& swaps_buffer, LineBuffer&& line_buffer,
209+ OutDegrees&& out_degrees)
210+ : swaps_buffer(std::forward<SwapsBuffer>(swaps_buffer)),
211+ line_buffer (std::forward<LineBuffer>(line_buffer)),
212+ out_degrees(std::forward<OutDegrees>(out_degrees)) {}
212213
213- public:
214- std::unique_ptr<int []> swaps;
215- idt (std::unique_ptr<int []>&& swaps, std::unique_ptr<int []>&& swaps_buffer,
216- std::unique_ptr<Scalar[]>&& line_buffer,
217- std::unique_ptr<OutDegree[]>&& out_degrees)
218- : swaps_buffer{std::move (swaps_buffer)},
219- line_buffer{std::move (line_buffer)},
220- out_degrees{std::move (out_degrees)},
221- swaps{std::move (swaps)} {}
222- static idt<Scalar> create (Tensor2DTyped<Scalar> const & prototype) {
223- std::unique_ptr<int []> swaps (new int [prototype.height ]);
214+ static idt_base<Scalar> create (Tensor2DTyped<Scalar> const & prototype) {
224215 std::unique_ptr<int []> swaps_buffer (new int [prototype.height ]);
225216 std::unique_ptr<Scalar[]> line_buffer (new Scalar[prototype.height ]);
226217 std::unique_ptr<adrt::OutDegree[]> out_degrees (
227218 new adrt::OutDegree[prototype.height ]);
228- return idt (std::move (swaps ), std::move (swaps_buffer ),
229- std::move (line_buffer), std::move (out_degrees));
219+ return idt_base (std::move (swaps_buffer ), std::move (line_buffer ),
220+ std::move (out_degrees));
230221 }
231- void recursive (Tensor2DTyped<Scalar> const & src, Sign sign) {
222+ };
223+
224+ template <typename Scalar>
225+ class idt_recursive {
226+ idt_base<Scalar> base;
227+
228+ public:
229+ std::unique_ptr<int []> swaps;
230+
231+ idt_recursive (idt_base<Scalar>&& base, std::unique_ptr<int []>&& swaps)
232+ : base{std::move (base)}, swaps{std::move (swaps)} {}
233+ static idt_recursive<Scalar> create (Tensor2DTyped<Scalar> const & prototype) {
234+ std::unique_ptr<int []> swaps (new int [prototype.height ]);
235+ return idt_recursive (idt_base<Scalar>::create (prototype), std::move (swaps));
236+ }
237+ void operator ()(Tensor2DTyped<Scalar> const & src, Sign sign) {
232238 std::fill (this ->swaps .get (), this ->swaps .get () + src.height , 0 );
233- fht2idt_recursive (src, sign, this ->swaps .get (), this ->swaps_buffer .get (),
234- this ->line_buffer .get (), this ->out_degrees .get (),
235- this ->t_B_to_check , this ->t_T_to_check ,
236- this ->t_processed );
239+ _fht2idt_recursive (src, sign, this ->swaps .get (),
240+ this ->base .swaps_buffer .get (),
241+ this ->base .line_buffer .get (),
242+ this ->base .out_degrees .get (), this ->base .t_B_to_check ,
243+ this ->base .t_T_to_check , this ->base .t_processed );
237244 }
245+ };
246+
247+ template <typename Scalar>
248+ class idt_non_recursive {
249+ idt_base<Scalar> base;
250+ std::vector<ADRTTask> tasks;
251+
252+ public:
253+ std::unique_ptr<int []> swaps;
254+ idt_non_recursive (idt_base<Scalar>&& base, std::unique_ptr<int []>&& swaps,
255+ std::vector<ADRTTask>&& tasks)
256+ : base{std::move (base)},
257+ swaps{std::move (swaps)},
258+ tasks{std::move (tasks)} {}
259+ static idt_non_recursive<Scalar> create (
260+ Tensor2DTyped<Scalar> const & prototype) {
261+ std::unique_ptr<int []> swaps (new int [prototype.height ]);
262+ std::vector<ADRTTask> tasks;
263+
264+ non_recursive (
265+ prototype.height ,
266+ [&](ADRTTask const & task) { tasks.emplace_back (task); },
267+ [](int val) {
268+ return static_cast <int >(div_by_pow2 (static_cast <uint32_t >(val)));
269+ });
238270
239- void non_recursive (Tensor2DTyped<Scalar> const & src, Sign sign) {
271+ return idt_non_recursive (idt_base<Scalar>::create (prototype),
272+ std::move (swaps), std::move (tasks));
273+ }
274+ void operator ()(Tensor2DTyped<Scalar> const & src, Sign sign) {
240275 std::fill (this ->swaps .get (), this ->swaps .get () + src.height , 0 );
241- fht2idt_non_recursive (src, sign, this ->swaps .get (),
242- this ->swaps_buffer .get (), this ->line_buffer .get (),
243- this ->out_degrees .get (), this ->t_B_to_check ,
244- this ->t_T_to_check , this ->t_processed );
276+ _fht2idt_non_recursive (
277+ src, sign, this ->swaps .get (), this ->base .swaps_buffer .get (),
278+ this ->base .line_buffer .get (), this ->base .out_degrees .get (),
279+ this ->base .t_B_to_check , this ->base .t_T_to_check ,
280+ this ->base .t_processed , this ->tasks );
245281 }
246282};
247283
248284template <typename Scalar>
249- using fht2idt = idt<Scalar>;
285+ using fht2idt_recursive = idt_recursive<Scalar>;
286+
287+ template <typename Scalar>
288+ using fht2idt_non_recursive = idt_non_recursive<Scalar>;
250289
251290} // namespace adrt
0 commit comments