Skip to content

Commit b9aa679

Browse files
committed
Add vectorized PrepareDerivatives, test, and benchmark
1 parent ce0fb6a commit b9aa679

File tree

6 files changed

+145
-1
lines changed

6 files changed

+145
-1
lines changed

cpp/bench/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ target_link_libraries(bench-hl-filter src-lib benchmark::benchmark_main)
2828
add_executable(bench-exps exps.cpp)
2929
target_link_libraries(bench-exps src-lib benchmark::benchmark_main)
3030

31+
add_executable(bench-prepare prepare-derivatives.cpp)
32+
target_link_libraries(bench-prepare src-lib benchmark::benchmark_main)
33+
3134
if (ENABLE_CILK)
3235
add_executable(bench-fftw-cilk fftw-cilk.cpp)
3336
target_link_libraries(bench-fftw-cilk fftw-cpp benchmark::benchmark fftw3_threads)

cpp/bench/prepare-derivatives.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include "PrepareDerivatives.hpp"
2+
3+
#include <benchmark/benchmark.h>
4+
5+
static constexpr auto N_WARMUP_ITERS = 5;
6+
7+
using namespace ahr;
8+
/// Benchmark body shared by every `Prepare` implementation.
///
/// The grid extents come from the benchmark arguments: `state.range(0)` is X
/// and `state.range(1)` is Y. The functor is invoked N_WARMUP_ITERS times
/// before the timed loop and once per timed iteration.
template <class Prepare> static void BM_Prepare(benchmark::State &state) {
    Dim const X = state.range(0);
    Dim const Y = state.range(1);

    Grid grid{1, X, Y};
    Prepare prepare{grid};

    auto source = grid.cBufXY();
    auto out_first = grid.cBufXY();
    auto out_second = grid.cBufXY();

    // A single invocation of the functor under test.
    auto const run_once = [&] { prepare(source, {out_first, out_second}); };

    // Untimed warm-up passes
    for (int pass = N_WARMUP_ITERS; pass > 0; --pass) {
        run_once();
    }

    // Timed passes
    for (auto _ : state) {
        run_once();
    }
}
28+
29+
// Register the non-vectorized implementation over every (X, Y) pair drawn
// from {2048, 4096, 8192} x {2048, 4096, 8192}; times are reported in ms.
BENCHMARK(BM_Prepare<PrepareDerivatives>)
30+
->ArgsProduct({{2048, 4096, 8192}, {2048, 4096, 8192}})
31+
->Unit(benchmark::kMillisecond);
32+
33+
// Same argument sweep and time unit for the vectorized implementation.
BENCHMARK(BM_Prepare<PrepareDerivativesVector>)
34+
->ArgsProduct({{2048, 4096, 8192}, {2048, 4096, 8192}})
35+
->Unit(benchmark::kMillisecond);

cpp/lib/PrepareDerivatives.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "PrepareDerivatives.hpp"
22

3+
#include <eve/module/core.hpp>
4+
35
namespace ahr {
46

57
void PrepareDerivatives::operator()(View::C_XY const &in, DxDy<View::C_XY> out) const {
@@ -9,4 +11,48 @@ void PrepareDerivatives::operator()(View::C_XY const &in, DxDy<View::C_XY> out)
911
});
1012
}
1113

14+
/// Vectorized version of the derivative preparation.
///
/// For every (kx, ky) of the grid this writes
///   out.DX(kx, ky) = kx * i * in(kx, ky) * XYNorm
///   out.DY(kx, ky) = ky_(ky) * i * in(kx, ky) * XYNorm
/// processing C_WIDTH complex values along kx per SIMD step and up to KY_TILE
/// rows of ky per outer iteration. Columns that do not fill a whole SIMD
/// register are handled by a scalar tail loop.
///
/// @param in  complex input; the SIMD path loads C_WIDTH consecutive elements
///            starting at &in(kx, ky), so elements adjacent in kx are assumed
///            to be adjacent in memory.
/// @param out pair of complex outputs (DX, DY) with the same layout assumption.
void PrepareDerivativesVector::operator()(View::C_XY const &in, DxDy<View::C_XY> out) const {
    // A register holds interleaved (re, im, re, im, ...) lanes; even lanes are real parts.
    eve::logical<VIdx> const even_mask{[](int idx, int) { return idx % 2 == 0; }};
    // Each complex value spans two lanes, so lane `idx` belongs to kx offset idx / 2.
    VReal const kx_v_init{[](int idx, int) { return Real(idx / 2); }};

    for (Dim ky = 0; ky < grid.KY; ky += KY_TILE) {
        // Height of this tile. It is KY_TILE except possibly for the last tile:
        // clamping keeps every access inside [0, KY) when KY is not a multiple
        // of KY_TILE.
        Dim tile_rows = KY_TILE;
        if (grid.KY - ky < tile_rows) {
            tile_rows = grid.KY - ky;
        }

        // broadcast ky values
        using TileReal = std::array<VReal, KY_TILE>;
        TileReal ky_v;
        for (Dim row = 0; row < tile_rows; ++row) {
            ky_v[row] = ky_(ky + row);
        }

        // Initialize kx values
        VReal kx_v = kx_v_init;

        Dim kx = 0;
        // Written as `kx + C_WIDTH <= KX` so the bound cannot go negative (or
        // wrap, should Dim be unsigned) when KX is smaller than C_WIDTH.
        for (; kx + C_WIDTH <= grid.KX; kx += C_WIDTH, kx_v += C_WIDTH) {
            for (Dim row = 0; row < tile_rows; ++row) {
                // View the complex storage as a flat array of Real lanes.
                // NOTE(review): presumably Complex is std::complex<Real>, for
                // which this reinterpretation is well-defined; confirm.
                auto const *src = reinterpret_cast<Real const *>(&in(kx, ky + row));
                auto *dst_dx = reinterpret_cast<Real *>(&out.DX(kx, ky + row));
                auto *dst_dy = reinterpret_cast<Real *>(&out.DY(kx, ky + row));

                auto const in_v = VReal{src};
                // (a + bi) * i -> (-b + ai)
                // swap real and imaginary parts
                auto const swapped = eve::swap_adjacent(in_v);
                // selectively negate
                auto const mul_with_i = eve::minus[even_mask](swapped);
                // normalize
                auto const in_norm = mul_with_i * XYNorm;

                eve::store(kx_v * in_norm, dst_dx);
                eve::store(ky_v[row] * in_norm, dst_dy);
            }
        }

        // tail
        for (; kx < grid.KX; ++kx) {
            for (Dim row = 0; row < tile_rows; ++row) {
                out.DX(kx, ky + row) = kx_(kx) * 1i * in(kx, ky + row) * XYNorm;
                out.DY(kx, ky + row) = ky_(ky + row) * 1i * in(kx, ky + row) * XYNorm;
            }
        }
    }
}
1258
} // namespace ahr

cpp/lib/PrepareDerivatives.hpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
#include "constants.hpp"
44
#include "grid.hpp"
5+
#include <eve/wide.hpp>
56

67
namespace ahr {
78

89
class PrepareDerivatives {
10+
protected:
911
using View = Grid::View;
1012
template <class T> using DxDy = Grid::DxDy<T>;
1113

@@ -14,7 +16,7 @@ class PrepareDerivatives {
1416

1517
void operator()(View::C_XY const &in, DxDy<View::C_XY> out) const;
1618

17-
private:
19+
protected:
1820
Grid const &grid;
1921

2022
/// Normalization factor for FFT
@@ -26,4 +28,22 @@ class PrepareDerivatives {
2628
}
2729
[[nodiscard]] Real kx_(Dim kx) const { return Real(kx); }
2830
};
31+
32+
class PrepareDerivativesVector : protected PrepareDerivatives {
33+
public:
34+
explicit PrepareDerivativesVector(Grid const &grid) : PrepareDerivatives(grid) {}
35+
36+
void operator()(View::C_XY const &in, DxDy<View::C_XY> out) const;
37+
38+
protected:
39+
using VReal = eve::wide<Real>;
40+
41+
// TODO vec_utils
42+
using VIdx = eve::wide<long long>;
43+
static auto constexpr R_WIDTH = VReal::size();
44+
static auto constexpr C_WIDTH = VReal::size() / 2;
45+
46+
static constexpr auto KY_TILE = 4;
47+
};
48+
2949
} // namespace ahr

cpp/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ add_viriato_test(exps.cpp)
2323
add_viriato_test(filter.cpp)
2424
add_viriato_test(grid.cpp)
2525
add_viriato_test(transformer.cpp)
26+
add_viriato_test(prepare-derivatives.cpp)
2627

2728
add_viriato_test(naive-energies.cpp)
2829
add_viriato_test(naive-moments.cpp)

cpp/test/prepare-derivatives.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "PrepareDerivatives.hpp"
2+
3+
#include "debug.hpp"
4+
#include "util.hpp"
5+
6+
#include <gtest/gtest.h>
7+
8+
namespace ahr {
9+
10+
/// Typed fixture: one grid shared by the reference PrepareDerivatives
/// (`prepare`) and the implementation under test (`prepare_t`).
template <class TestedPrepare> class TestPrepareDerivatives : public ::testing::Test {
protected:
    /// Grid both functors are constructed from; declared first so it is
    /// initialized before them.
    Grid grid{5, 32, 32};
    /// Reference implementation.
    PrepareDerivatives prepare{grid};
    /// Implementation being checked against the reference.
    TestedPrepare prepare_t{grid};
};
16+
17+
// Implementations instantiated for the TestPrepareDerivatives typed suite.
using Types = ::testing::Types<PrepareDerivativesVector>;
18+
TYPED_TEST_SUITE(TestPrepareDerivatives, Types);
19+
20+
/// Feeds identical input to the reference and the tested implementation and
/// requires their DX and DY outputs to match element-wise within tolerance.
TYPED_TEST(TestPrepareDerivatives, Prepare) {
    auto const KX = this->grid.KX;
    auto const KY = this->grid.KY;

    auto buf = this->grid.cBufXY();
    auto buf_t = this->grid.cBufXY();
    Grid::DxDy<Grid::Buf::C_XY> expected{KX, KY};
    Grid::DxDy<Grid::Buf::C_XY> actual{KX, KY};

    // Give both input buffers the same value at every (kx, ky).
    this->grid.for_each_kxky([&](Dim kx, Dim ky) {
        auto const re = std::sin(2 * pi * kx / KX);
        auto const im = std::cos(2 * pi * ky / KY);
        auto const value = 1024.0 * Complex{re, im};
        buf_t(kx, ky) = value;
        buf(kx, ky) = value;
    });

    // Reference first, then the implementation under test.
    this->prepare(buf, expected);
    this->prepare_t(buf_t, actual);

    ASSERT_THAT(actual.DX.to_mdspan(), MdspanElementsAllClose(expected.DX.to_mdspan(), 1e-16, 1e-15));
    ASSERT_THAT(actual.DY.to_mdspan(), MdspanElementsAllClose(expected.DY.to_mdspan(), 1e-16, 1e-15));
}
38+
39+
} // namespace ahr

0 commit comments

Comments
 (0)