Skip to content

Commit 7e19ee9

Browse files
committed
Use Catch2 for tests
1 parent 505d9c9 commit 7e19ee9

File tree

2 files changed

+14
-40
lines changed

2 files changed

+14
-40
lines changed

tests/CMakeLists.txt

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
add_executable(test main.cpp)
2-
target_include_directories(test PRIVATE ${PROJECT_SOURCE_DIR}/include)
3-
set_target_properties(test PROPERTIES
4-
RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}
5-
)
1+
find_package(Catch2 3 REQUIRED)
2+
add_executable(tests tests.cpp)
3+
target_include_directories(tests PRIVATE ${PROJECT_SOURCE_DIR}/include)
4+
target_link_libraries(tests PRIVATE Catch2::Catch2WithMain)
5+
6+
include(CTest)
7+
include(Catch)
8+
catch_discover_tests(tests)
Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,16 @@
1+
#include <catch2/catch_test_macros.hpp>
12
#include <random>
23
#include <bit>
3-
#include <cassert>
44
#include <vector>
55

66
#include "gemmi.hpp"
77
#include "utilities.hpp"
88

9-
int main() {
10-
11-
// typedef float my_fp_type;
12-
// typedef uint32_t my_int_type;
9+
TEST_CASE("GEMMI accuracy", "[gemmi]") {
1310

1411
typedef double my_fp_type;
1512

16-
size_t ms = 2, ns = 2, ps = 2;
17-
std::vector<my_fp_type> As(ms * ps);
18-
std::vector<my_fp_type> Bs(ps * ns);
19-
std::default_random_engine generator(std::random_device{}());
20-
std::uniform_real_distribution<double> distribution(-100000.0, 100000.0);
21-
for (auto & element : As)
22-
element = distribution(generator);
23-
for (auto & element : Bs)
24-
element = distribution(generator);
25-
As[0] = 1.984375; // 0x3FFE0000 -> 1.11111 10000 00000 00000 000
26-
As[1] = 1.999969482421875; // 0x3FFFFF00 -> 1.11111 11111 11111 00000 000
27-
As[2] = 1.99993896484375; // 0x3FFFFE00 -> 1.11111 11111 11110 00000 000
28-
As[3] = 1.9998779296875; // 0x3FFFCE00 -> 1.11111 11111 11100 00000 000
29-
30-
auto Cs = gemmi<my_fp_type, int8_t, int32_t>(As, Bs, ms, ps, ns, 10);
31-
auto Cs_ref = reference_gemm(As, Bs, ms, ps, ns);
32-
33-
double relErr = frobenius_norm<my_fp_type, double>(Cs - Cs_ref) / frobenius_norm<my_fp_type, double>(Cs);
34-
35-
std::cout << "Relative error: " << relErr << std::endl;
36-
assert(relErr < 1e-15);
37-
38-
// Test different sizes.
13+
// Test different sizes
3914
for (size_t numSplitA : { 1, 2, 10 }) {
4015
for (size_t numSplitB : { 1, 2, 10 }) {
4116
for (size_t m = 10; m <= 50; m += 10) {
@@ -44,27 +19,23 @@ int main() {
4419
std::vector<my_fp_type> A(m * p);
4520
std::vector<my_fp_type> B(p * n);
4621

47-
std::cout << "m: " << m << ", p: " << p << ", n: " << n << std::endl;
48-
4922
// Initalize matrix with random values.
5023
std::default_random_engine generator(std::random_device{}());
5124
std::uniform_real_distribution<double> distribution(-100000.0, 100000.0);
5225
for (auto & element : A)
5326
element = numSplitA < 10 ? ldexp(1.0, 2 * numSplitA) - 1 : distribution(generator);
5427
for (auto & element : B)
5528
element = numSplitB < 10 ? ldexp(1.0, 2 * numSplitB) - 1 : distribution(generator);
29+
5630
auto C = gemmi<my_fp_type, int8_t, int32_t>(A, B, m, p, n, numSplitA, numSplitB);
5731
auto C_ref = reference_gemm(A, B, m, p, n);
5832

5933
double relative_error = frobenius_norm<my_fp_type, double>(C - C_ref) / frobenius_norm<my_fp_type, double>(C);
6034

61-
std::cout << "Relative error: " << relative_error << std::endl;
62-
assert(relative_error < 1e-15);
35+
REQUIRE(relative_error < 1e-15);
6336
}
6437
}
6538
}
6639
}
6740
}
68-
69-
return 0;
70-
}
41+
}

0 commit comments

Comments
 (0)