Skip to content

Commit c254187

Browse files
committed
Extract filtering to Filter class (TODO tests). Add cached versions and benchmark.
1 parent fb465de commit c254187

File tree

8 files changed

+130
-12
lines changed

8 files changed

+130
-12
lines changed

cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ target_link_libraries(adaptive_hermite_refinement argparse mdspan)
6161
add_library(src-lib OBJECT lib/Naive.cpp lib/Naive.hpp lib/HermiteRunner.cpp lib/HermiteRunner.hpp
6262
lib/Transformer.hpp lib/Transformer.cpp
6363
lib/Exporter.hpp lib/Exporter.cpp
64+
lib/Filter.hpp lib/Filter.cpp
6465
)
6566
target_include_directories(src-lib PUBLIC lib/)
6667
target_link_libraries(src-lib mdspan fftw-cpp cnpy spdlog::spdlog)

cpp/bench/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ FetchContent_MakeAvailable(google-benchmark)
2222
add_executable(bench-naive naive.cpp main-no-log.cpp)
2323
target_link_libraries(bench-naive src-lib benchmark::benchmark)
2424

25+
add_executable(bench-hl-filter hl-filter.cpp)
26+
target_link_libraries(bench-hl-filter src-lib benchmark::benchmark_main)
27+
2528
if (ENABLE_CILK)
2629
add_executable(bench-fftw-cilk fftw-cilk.cpp)
2730
target_link_libraries(bench-fftw-cilk fftw-cpp benchmark::benchmark fftw3_threads)

cpp/bench/hl-filter.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#include "Filter.hpp"
2+
3+
#include <benchmark/benchmark.h>
4+
5+
static constexpr auto N_WARMUP_ITERS = 5;
6+
7+
using namespace ahr;
8+
template <class Filter> static void BM_HouLiFilter(benchmark::State &state) {
9+
Dim const X = state.range(0);
10+
Dim const Y = state.range(1);
11+
12+
Grid grid{1, X, Y};
13+
Filter filter{grid};
14+
15+
auto buf = grid.cBufXY();
16+
17+
// Warm-up
18+
for (int i = 0; i < N_WARMUP_ITERS; i++) {
19+
filter(buf);
20+
}
21+
22+
for (auto _ : state) {
23+
filter(buf);
24+
}
25+
}
26+
27+
BENCHMARK(BM_HouLiFilter<HouLiFilter>)
28+
->ArgsProduct({{2048, 4096, 8192}, {2048, 4096, 8192}})
29+
->Unit(benchmark::kMillisecond);
30+
31+
BENCHMARK(BM_HouLiFilter<HouLiFilterCached>)
32+
->ArgsProduct({{2048, 4096, 8192}, {2048, 4096, 8192}})
33+
->Unit(benchmark::kMillisecond);
34+
35+
BENCHMARK(BM_HouLiFilter<HouLiFilterCached1D>)
36+
->ArgsProduct({{2048, 4096, 8192}, {2048, 4096, 8192}})
37+
->Unit(benchmark::kMillisecond);

cpp/bench/naive.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
using namespace ahr;
77
static void BM_Naive(benchmark::State &state) {
8-
std::ostringstream oss;
9-
108
Dim const M = state.range(0);
119
Dim const X = state.range(1);
1210
Dim const N = state.range(2);

cpp/lib/Filter.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//
2+
// Created by Luka on 11/2/2024.
3+
//
4+
5+
#include "Filter.hpp"
6+
namespace ahr {
7+
void HouLiFilter::operator()(Grid::View::C_XY view) {
8+
grid.for_each_kxky([&](Dim kx, Dim ky) {
9+
view(kx, ky) *=
10+
exp(-36.0 * pow(kx_(kx) / grid.KX, 36.0)) * exp(-36.0 * pow(ky_(ky) / grid.KY, 36.0));
11+
});
12+
}
13+
14+
HouLiFilterCached::HouLiFilterCached(Grid const &grid)
15+
: HouLiFilter(grid), factors(std::array{grid.KX, grid.KY}) {
16+
grid.for_each_kxky([&](Dim kx, Dim ky) {
17+
factors(kx, ky) =
18+
exp(-36.0 * pow(kx_(kx) / grid.KX, 36.0)) * exp(-36.0 * pow(ky_(ky) / grid.KY, 36.0));
19+
});
20+
}
21+
22+
void HouLiFilterCached::operator()(Grid::View::C_XY view) {
23+
grid.for_each_kxky([&](Dim kx, Dim ky) { view(kx, ky) *= factors(kx, ky); });
24+
}
25+
26+
HouLiFilterCached1D::HouLiFilterCached1D(Grid const &grid)
27+
: HouLiFilter(grid), factors_x(grid.KX), factors_y(grid.KY) {
28+
for (Dim kx = 0; kx < grid.KX; ++kx) {
29+
factors_x.at(kx) = exp(-36.0 * pow(kx_(kx) / grid.KX, 36.0));
30+
}
31+
for (Dim ky = 0; ky < grid.KY; ++ky) {
32+
factors_y.at(ky) = exp(-36.0 * pow(ky_(ky) / grid.KY, 36.0));
33+
}
34+
}
35+
36+
void HouLiFilterCached1D::operator()(Grid::View::C_XY view) {
37+
grid.for_each_kxky([&](Dim kx, Dim ky) {
38+
// Extra multiplication at runtime for lower memory cost
39+
view(kx, ky) *= factors_x[kx] * factors_y[ky];
40+
});
41+
}
42+
} // namespace ahr

cpp/lib/Filter.hpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#pragma once
2+
#include "constants.hpp"
3+
#include "grid.hpp"
4+
5+
namespace ahr {
6+
7+
class HouLiFilter {
8+
public:
9+
explicit HouLiFilter(Grid const &grid) : grid(grid) {}
10+
11+
void operator()(Grid::View::C_XY view);
12+
13+
protected:
14+
Grid const &grid;
15+
16+
// TODO extract to common utility
17+
[[nodiscard]] Real ky_(Dim ky) const {
18+
return (ky <= (grid.KY / 2) ? Real(ky) : Real(ky) - Real(grid.KY)) * Real(lx) / Real(ly);
19+
}
20+
[[nodiscard]] Real kx_(Dim kx) const { return Real(kx); }
21+
};
22+
23+
class HouLiFilterCached : HouLiFilter {
24+
public:
25+
explicit HouLiFilterCached(Grid const &grid);
26+
void operator()(Grid::View::C_XY view);
27+
28+
private:
29+
/// Pre-calculated factors for the Hou-Li filter.
30+
/// Note that this is a real buffer with dimensions (KX,KY)
31+
Grid::Buf::R_XY factors;
32+
};
33+
34+
class HouLiFilterCached1D : protected HouLiFilter {
35+
public:
36+
explicit HouLiFilterCached1D(Grid const &grid);
37+
void operator()(Grid::View::C_XY view);
38+
39+
protected:
40+
std::vector<Real> factors_x, factors_y;
41+
};
42+
} // namespace ahr

cpp/lib/Naive.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,6 @@ Naive::Naive(Dim M, Dim X, Dim Y)
1919
assert((Y & (Y - 1)) == 0);
2020
}
2121

22-
void Naive::hlFilter(View::C_XY &complexArray) {
23-
g.for_each_kxky([&](Dim kx, Dim ky) {
24-
complexArray(kx, ky) *=
25-
exp(-36.0 * pow(kx_(kx) / g.KX, 36.0)) * exp(-36.0 * pow(ky_(ky) / g.KY, 36.0));
26-
});
27-
}
28-
2922
void Naive::fftHL(View::R_XY in, View::C_XY out) {
3023
tf.fft(in, out);
3124
hlFilter(out);

cpp/lib/Naive.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "Exporter.hpp"
4+
#include "Filter.hpp"
45
#include "HermiteRunner.hpp"
56
#include "Transformer.hpp"
67
#include "constants.hpp"
@@ -35,6 +36,8 @@ class Naive : public ahr::HermiteRunner {
3536
Grid g;
3637
Transformer tf{g};
3738
Exporter exporter{g, tf};
39+
HouLiFilter hlFilter{g};
40+
3841
private:
3942

4043
using View = Grid::View;
@@ -45,7 +48,6 @@ class Naive : public ahr::HermiteRunner {
4548
Real dt{-1}; ///< timestep
4649
Real elapsedT{0.0}; ///< total time elapsed
4750

48-
void hlFilter(View::C_XY &complexArray);
4951
void fftHL(View::R_XY in, View::C_XY out); ///< FFT with Hou-Li Filter
5052

5153
Real bPerpMax{0};
@@ -127,8 +129,8 @@ class Naive : public ahr::HermiteRunner {
127129

128130
[[nodiscard]] Real ky_(Dim ky) const {
129131
return (ky <= (g.KY / 2) ? Real(ky) : Real(ky) - Real(g.KY)) * Real(lx) / Real(ly);
130-
};
131-
[[nodiscard]] Real kx_(Dim kx) const { return Real(kx); };
132+
}
133+
[[nodiscard]] Real kx_(Dim kx) const { return Real(kx); }
132134

133135
[[nodiscard]] Real kPerp2(Dim kx, Dim ky) const {
134136
auto dkx = kx_(kx), dky = ky_(ky);

0 commit comments

Comments
 (0)