Skip to content

Commit f7b0bb6

Browse files
committed
filter tests
1 parent ffcb60f commit f7b0bb6

File tree

3 files changed

+48
-3
lines changed

3 files changed

+48
-3
lines changed

cpp/test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ function(add_viriato_test FILE)
1818
add_viriato_test_impl(${NAME} ${FILE} ${ARGN})
1919
endfunction()
2020

21+
add_viriato_test(exporter.cpp)
22+
add_viriato_test(filter.cpp)
2123
add_viriato_test(grid.cpp)
2224
add_viriato_test(transformer.cpp)
23-
add_viriato_test(exporter.cpp)
2425

2526
add_viriato_test(naive-energies.cpp)
2627
add_viriato_test(naive-moments.cpp)

cpp/test/filter.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#include "Filter.hpp"
2+
#include "grid.hpp"
3+
4+
#include "util.hpp"
5+
6+
#include <gtest/gtest.h>
7+
8+
namespace ahr {
9+
10+
template <typename TestedFilter> class TestFilter : public ::testing::Test {
11+
protected:
12+
Grid grid{5, 9, 16};
13+
HouLiFilter filter{grid};
14+
TestedFilter filter_t{grid};
15+
};
16+
17+
using Types = ::testing::Types<HouLiFilterCached, HouLiFilterCached1D>;
18+
TYPED_TEST_SUITE(TestFilter, Types);
19+
20+
TYPED_TEST(TestFilter, Filter) {
21+
auto buf = this->grid.cBufXY();
22+
auto buf_t = this->grid.cBufXY();
23+
24+
// initialize buffers
25+
this->grid.for_each_kxky([&](Dim kx, Dim ky) {
26+
buf(kx, ky) = {std::sin(2 * pi * kx), std::cos(2 * pi * ky)};
27+
buf_t(kx, ky) = {std::sin(2 * pi * kx), std::cos(2 * pi * ky)};
28+
});
29+
30+
this->filter(buf);
31+
this->filter_t(buf_t);
32+
33+
ASSERT_THAT(buf_t.to_mdspan(), MdspanElementsAllClose(buf.to_mdspan(), 1e-16, 1e-15));
34+
}
35+
36+
} // namespace ahr

cpp/test/util.hpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
#include <fftw-cpp/fftw-cpp.h>
44
#include <gmock/gmock.h>
5-
#include <tuple>
65
#include <iomanip>
6+
#include <tuple>
77

88
// Function to slice the tuple
99
template <std::size_t Start, std::size_t End, typename Tuple, std::size_t... Indices>
@@ -21,11 +21,19 @@ template <std::size_t Start, typename... Args> auto slice(const std::tuple<Args.
2121
return slice<Start, sizeof...(Args)>(t);
2222
}
2323

24+
template <typename T> T absdiff_no_overflow(T a, T b) { return std::max(a, b) - std::min(a, b); }
25+
26+
template <typename T> T absdiff_no_overflow(std::complex<T> a, std::complex<T> b) {
27+
auto const real_diff = absdiff_no_overflow(a.real(), b.real());
28+
auto const imag_diff = absdiff_no_overflow(a.imag(), b.imag());
29+
return std::abs(std::complex<T>{real_diff, imag_diff});
30+
}
31+
2432
using ::testing::PrintToString;
2533
MATCHER_P3(AllClose, val, rel_tol, abs_tol,
2634
PrintToString(val) + " ±" + PrintToString(abs_tol) + " (±" +
2735
PrintToString(double(rel_tol) * std::abs(val)) + ")") {
28-
auto diff = std::max(arg, val) - std::min(arg, val);
36+
auto diff = absdiff_no_overflow(val, arg);
2937
double tolerance_diff = double(diff) - double(abs_tol) - double(rel_tol) * std::abs(val);
3038
return tolerance_diff <= 0;
3139
}

0 commit comments

Comments
 (0)