Skip to content

Commit eb346a8

Browse files
committed
feat!: store precomputed tree for idt_recursive algorithm
for a measurable speedup
1 parent 2e41650 commit eb346a8

File tree

4 files changed

+122
-75
lines changed

4 files changed

+122
-75
lines changed

adrtlib/_adrtlib.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,23 @@ auto py_ids(Image2D &image, adrt::Sign sign, Recursive recursive) {
7979
template <typename Scalar>
8080
static auto py_idt_visit(adrt::Tensor2D const &tensor, adrt::Sign sign,
8181
Recursive recursive) {
82-
auto idt_core = adrt::idt<Scalar>::create(tensor.as<Scalar>());
82+
std::unique_ptr<int[]> swaps;
8383
if (recursive == Recursive::Yes) {
84-
idt_core.recursive(tensor.as<Scalar>(), sign);
84+
auto idt_recursive =
85+
adrt::idt_recursive<Scalar>::create(tensor.as<Scalar>());
86+
idt_recursive(tensor.as<Scalar>(), sign);
87+
swaps = std::move(idt_recursive.swaps);
8588
} else {
86-
idt_core.non_recursive(tensor.as<Scalar>(), sign);
89+
auto idt_non_recursive =
90+
adrt::idt_non_recursive<Scalar>::create(tensor.as<Scalar>());
91+
idt_non_recursive(tensor.as<Scalar>(), sign);
92+
swaps = std::move(idt_non_recursive.swaps);
8793
}
88-
nb::capsule swaps_owner(idt_core.swaps.get(),
94+
nb::capsule swaps_owner(swaps.get(),
8995
[](void *p) noexcept { delete[] (int *)p; });
9096

9197
return nb::ndarray<nb::numpy, int, nb::ndim<1>, nb::device::cpu>(
92-
/* data = */ idt_core.swaps.release(),
98+
/* data = */ swaps.release(),
9399
/* shape = */ {static_cast<size_t>(tensor.height)},
94100
/* owner = */ swaps_owner);
95101
}

adrtlib/benchmark/adrtlib_benchmark.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,16 @@ static void BM_fht2idt(benchmark::State &state, IsRecursive is_recursive) {
8787
reinterpret_cast<uint8_t *>(src.get())};
8888
adrt::Sign const sign = adrt::Sign::Positive;
8989

90-
auto idt_code = adrt::idt<float>::create(tensor.as<float>());
91-
9290
if (is_recursive == IsRecursive::Yes) {
91+
auto idt_recursive = adrt::idt_recursive<float>::create(tensor.as<float>());
9392
for (auto _ : state) {
94-
idt_code.recursive(tensor.as<float>(), sign);
93+
idt_recursive(tensor.as<float>(), sign);
9594
}
9695
} else {
96+
auto idt_non_recursive =
97+
adrt::idt_non_recursive<float>::create(tensor.as<float>());
9798
for (auto _ : state) {
98-
idt_code.non_recursive(tensor.as<float>(), sign);
99+
idt_non_recursive(tensor.as<float>(), sign);
99100
}
100101
}
101102
state.SetBytesProcessed(int64_t(state.iterations()) *

adrtlib/include/adrtlib/fht2idt.hpp

Lines changed: 99 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,11 @@ static inline void fht2idt_core(
139139
}
140140

141141
template <typename Scalar>
142-
void fht2idt_recursive(Tensor2DTyped<Scalar> const& src, Sign sign, int swaps[],
143-
int swaps_buffer[], Scalar line_buffer[],
144-
OutDegree out_degrees[], std::vector<int>& t_B_to_check,
145-
std::vector<int>& t_T_to_check,
146-
std::vector<bool>& t_processed) {
142+
void _fht2idt_recursive(Tensor2DTyped<Scalar> const& src, Sign sign,
143+
int swaps[], int swaps_buffer[], Scalar line_buffer[],
144+
OutDegree out_degrees[], std::vector<int>& t_B_to_check,
145+
std::vector<int>& t_T_to_check,
146+
std::vector<bool>& t_processed) {
147147
auto const height = src.height;
148148
if A_UNLIKELY (height <= 1) {
149149
return;
@@ -153,13 +153,13 @@ void fht2idt_recursive(Tensor2DTyped<Scalar> const& src, Sign sign, int swaps[],
153153
Tensor2D const I_B{slice_no_checks(src, h_T, src.height)};
154154

155155
if (I_T.height > 1) {
156-
fht2idt_recursive(I_T.as<Scalar>(), sign, swaps, swaps_buffer, line_buffer,
157-
out_degrees, t_B_to_check, t_T_to_check, t_processed);
156+
_fht2idt_recursive(I_T.as<Scalar>(), sign, swaps, swaps_buffer, line_buffer,
157+
out_degrees, t_B_to_check, t_T_to_check, t_processed);
158158
}
159159
if (I_B.height > 1) {
160-
fht2idt_recursive(I_B.as<Scalar>(), sign, swaps + h_T, swaps_buffer + h_T,
161-
line_buffer, out_degrees, t_B_to_check, t_T_to_check,
162-
t_processed);
160+
_fht2idt_recursive(I_B.as<Scalar>(), sign, swaps + h_T, swaps_buffer + h_T,
161+
line_buffer, out_degrees, t_B_to_check, t_T_to_check,
162+
t_processed);
163163
}
164164
std::memcpy(swaps_buffer, swaps, height * sizeof(swaps_buffer[0]));
165165
fht2idt_core(height, sign, swaps, swaps_buffer + 0, swaps_buffer + h_T,
@@ -170,82 +170,121 @@ void fht2idt_recursive(Tensor2DTyped<Scalar> const& src, Sign sign, int swaps[],
170170
}
171171

172172
template <typename Scalar>
173-
void fht2idt_non_recursive(Tensor2DTyped<Scalar> const& src, Sign sign,
174-
int swaps[], int swaps_buffer[],
175-
Scalar line_buffer[], OutDegree out_degrees[],
176-
std::vector<int>& t_B_to_check,
177-
std::vector<int>& t_T_to_check,
178-
std::vector<bool>& t_processed) {
173+
void _fht2idt_non_recursive(Tensor2DTyped<Scalar> const& src, Sign sign,
174+
int swaps[], int swaps_buffer[],
175+
Scalar line_buffer[], OutDegree out_degrees[],
176+
std::vector<int>& t_B_to_check,
177+
std::vector<int>& t_T_to_check,
178+
std::vector<bool>& t_processed,
179+
std::vector<ADRTTask> const& tasks) {
179180
auto const height = src.height;
180181
if A_UNLIKELY (height <= 1) {
181182
return;
182183
}
183-
184-
non_recursive(
185-
height,
186-
[&](ADRTTask const& task) {
187-
A_NEVER(task.size < 2);
188-
Tensor2D const I_T{slice_no_checks(src, task.start, task.mid)};
189-
Tensor2D const I_B{slice_no_checks(src, task.mid, task.stop)};
190-
int* cur_swaps_buffer = swaps_buffer + task.start;
191-
int* cur_swaps = swaps + task.start;
192-
std::memcpy(cur_swaps_buffer, cur_swaps,
193-
task.size * sizeof(swaps_buffer[0]));
194-
fht2idt_core(task.size, sign, cur_swaps, cur_swaps_buffer,
195-
swaps_buffer + task.mid, line_buffer, I_T.as<Scalar>(),
196-
I_B.as<Scalar>(), out_degrees, t_B_to_check, t_T_to_check,
197-
t_processed);
198-
},
199-
[](int val) {
200-
return static_cast<int>(div_by_pow2(static_cast<uint32_t>(val)));
201-
});
184+
for (ADRTTask const& task : tasks) {
185+
A_NEVER(task.size < 2);
186+
Tensor2D const I_T{slice_no_checks(src, task.start, task.mid)};
187+
Tensor2D const I_B{slice_no_checks(src, task.mid, task.stop)};
188+
int* cur_swaps_buffer = swaps_buffer + task.start;
189+
int* cur_swaps = swaps + task.start;
190+
std::memcpy(cur_swaps_buffer, cur_swaps,
191+
task.size * sizeof(swaps_buffer[0]));
192+
fht2idt_core(task.size, sign, cur_swaps, cur_swaps_buffer,
193+
swaps_buffer + task.mid, line_buffer, I_T.as<Scalar>(),
194+
I_B.as<Scalar>(), out_degrees, t_B_to_check, t_T_to_check,
195+
t_processed);
196+
}
202197
}
203198

204199
template <typename Scalar>
205-
class idt {
200+
struct idt_base {
206201
std::unique_ptr<int[]> swaps_buffer;
207202
std::unique_ptr<Scalar[]> line_buffer;
208203
std::unique_ptr<OutDegree[]> out_degrees;
209204
std::vector<int> t_B_to_check;
210205
std::vector<int> t_T_to_check;
211206
std::vector<bool> t_processed;
207+
template <typename SwapsBuffer, typename LineBuffer, typename OutDegrees>
208+
idt_base(SwapsBuffer&& swaps_buffer, LineBuffer&& line_buffer,
209+
OutDegrees&& out_degrees)
210+
: swaps_buffer(std::forward<SwapsBuffer>(swaps_buffer)),
211+
line_buffer(std::forward<LineBuffer>(line_buffer)),
212+
out_degrees(std::forward<OutDegrees>(out_degrees)) {}
212213

213-
public:
214-
std::unique_ptr<int[]> swaps;
215-
idt(std::unique_ptr<int[]>&& swaps, std::unique_ptr<int[]>&& swaps_buffer,
216-
std::unique_ptr<Scalar[]>&& line_buffer,
217-
std::unique_ptr<OutDegree[]>&& out_degrees)
218-
: swaps_buffer{std::move(swaps_buffer)},
219-
line_buffer{std::move(line_buffer)},
220-
out_degrees{std::move(out_degrees)},
221-
swaps{std::move(swaps)} {}
222-
static idt<Scalar> create(Tensor2DTyped<Scalar> const& prototype) {
223-
std::unique_ptr<int[]> swaps(new int[prototype.height]);
214+
static idt_base<Scalar> create(Tensor2DTyped<Scalar> const& prototype) {
224215
std::unique_ptr<int[]> swaps_buffer(new int[prototype.height]);
225216
std::unique_ptr<Scalar[]> line_buffer(new Scalar[prototype.height]);
226217
std::unique_ptr<adrt::OutDegree[]> out_degrees(
227218
new adrt::OutDegree[prototype.height]);
228-
return idt(std::move(swaps), std::move(swaps_buffer),
229-
std::move(line_buffer), std::move(out_degrees));
219+
return idt_base(std::move(swaps_buffer), std::move(line_buffer),
220+
std::move(out_degrees));
230221
}
231-
void recursive(Tensor2DTyped<Scalar> const& src, Sign sign) {
222+
};
223+
224+
template <typename Scalar>
225+
class idt_recursive {
226+
idt_base<Scalar> base;
227+
228+
public:
229+
std::unique_ptr<int[]> swaps;
230+
231+
idt_recursive(idt_base<Scalar>&& base, std::unique_ptr<int[]>&& swaps)
232+
: base{std::move(base)}, swaps{std::move(swaps)} {}
233+
static idt_recursive<Scalar> create(Tensor2DTyped<Scalar> const& prototype) {
234+
std::unique_ptr<int[]> swaps(new int[prototype.height]);
235+
return idt_recursive(idt_base<Scalar>::create(prototype), std::move(swaps));
236+
}
237+
void operator()(Tensor2DTyped<Scalar> const& src, Sign sign) {
232238
std::fill(this->swaps.get(), this->swaps.get() + src.height, 0);
233-
fht2idt_recursive(src, sign, this->swaps.get(), this->swaps_buffer.get(),
234-
this->line_buffer.get(), this->out_degrees.get(),
235-
this->t_B_to_check, this->t_T_to_check,
236-
this->t_processed);
239+
_fht2idt_recursive(src, sign, this->swaps.get(),
240+
this->base.swaps_buffer.get(),
241+
this->base.line_buffer.get(),
242+
this->base.out_degrees.get(), this->base.t_B_to_check,
243+
this->base.t_T_to_check, this->base.t_processed);
237244
}
245+
};
246+
247+
template <typename Scalar>
248+
class idt_non_recursive {
249+
idt_base<Scalar> base;
250+
std::vector<ADRTTask> tasks;
251+
252+
public:
253+
std::unique_ptr<int[]> swaps;
254+
idt_non_recursive(idt_base<Scalar>&& base, std::unique_ptr<int[]>&& swaps,
255+
std::vector<ADRTTask>&& tasks)
256+
: base{std::move(base)},
257+
swaps{std::move(swaps)},
258+
tasks{std::move(tasks)} {}
259+
static idt_non_recursive<Scalar> create(
260+
Tensor2DTyped<Scalar> const& prototype) {
261+
std::unique_ptr<int[]> swaps(new int[prototype.height]);
262+
std::vector<ADRTTask> tasks;
263+
264+
non_recursive(
265+
prototype.height,
266+
[&](ADRTTask const& task) { tasks.emplace_back(task); },
267+
[](int val) {
268+
return static_cast<int>(div_by_pow2(static_cast<uint32_t>(val)));
269+
});
238270

239-
void non_recursive(Tensor2DTyped<Scalar> const& src, Sign sign) {
271+
return idt_non_recursive(idt_base<Scalar>::create(prototype),
272+
std::move(swaps), std::move(tasks));
273+
}
274+
void operator()(Tensor2DTyped<Scalar> const& src, Sign sign) {
240275
std::fill(this->swaps.get(), this->swaps.get() + src.height, 0);
241-
fht2idt_non_recursive(src, sign, this->swaps.get(),
242-
this->swaps_buffer.get(), this->line_buffer.get(),
243-
this->out_degrees.get(), this->t_B_to_check,
244-
this->t_T_to_check, this->t_processed);
276+
_fht2idt_non_recursive(
277+
src, sign, this->swaps.get(), this->base.swaps_buffer.get(),
278+
this->base.line_buffer.get(), this->base.out_degrees.get(),
279+
this->base.t_B_to_check, this->base.t_T_to_check,
280+
this->base.t_processed, this->tasks);
245281
}
246282
};
247283

248284
template <typename Scalar>
249-
using fht2idt = idt<Scalar>;
285+
using fht2idt_recursive = idt_recursive<Scalar>;
286+
287+
template <typename Scalar>
288+
using fht2idt_non_recursive = idt_non_recursive<Scalar>;
250289

251290
} // namespace adrt

adrtlib/test/adrtlib_test.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,17 +296,18 @@ static std::vector<ADRTTestCase> GenerateTestFHT2DSCases() {
296296
FunctionPair(
297297
[](adrt::Tensor2DTyped<float> const &dst,
298298
adrt::Tensor2DTyped<float> const &src, adrt::Sign sign) {
299-
auto fht2idt_core = adrt::fht2idt<float>::create(src);
300-
fht2idt_core.recursive(src, sign);
301-
unswap_tensor(dst, src, fht2idt_core.swaps.get());
299+
auto idt_recursive = adrt::idt_recursive<float>::create(src);
300+
idt_recursive(src, sign);
301+
unswap_tensor(dst, src, idt_recursive.swaps.get());
302302
},
303303
"fht2idt_recursive", FunctionType::fht2dt, IsInplace::Yes),
304304
FunctionPair(
305305
[](adrt::Tensor2DTyped<float> const &dst,
306306
adrt::Tensor2DTyped<float> const &src, adrt::Sign sign) {
307-
auto fht2idt_core = adrt::fht2idt<float>::create(src);
308-
fht2idt_core.non_recursive(src, sign);
309-
unswap_tensor(dst, src, fht2idt_core.swaps.get());
307+
auto idt_non_recursive =
308+
adrt::idt_non_recursive<float>::create(src);
309+
idt_non_recursive(src, sign);
310+
unswap_tensor(dst, src, idt_non_recursive.swaps.get());
310311
},
311312
"fht2idt_non_recursive", FunctionType::fht2dt, IsInplace::Yes),
312313
FunctionPair(

0 commit comments

Comments
 (0)