Skip to content

Commit 14d1f41

Browse files
committed
Add binary32 tests
1 parent 71200bf commit 14d1f41

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

.github/workflows/run_cpp_tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ jobs:
6464
- uses: codecov/codecov-action@v5
6565
if: runner.os == 'Linux'
6666
with:
67+
disable_search: true
6768
files: build/coverage.xml
6869
verbose: true
6970
env:

tests/tests.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,55 @@
22
#include <random>
33
#include <bit>
44
#include <vector>
5+
#include <iostream>
56

67
#include "gemmi.hpp"
78
#include "utilities.hpp"
89

9-
TEST_CASE("GEMMI accuracy", "[gemmi]") {
10-
11-
typedef double my_fp_type;
10+
template <typename fp_t> double tolerance() {return 0;}
11+
template <> double tolerance<float>() {return 1e-6;}
12+
template <> double tolerance<double>() {return 1e-15;}
1213

14+
template <typename fp_t>
15+
void runTest() {
1316
// Test different sizes
14-
for (auto splitType : {splittingStrategy::bitMasking, splittingStrategy::roundToNearest}) {
17+
for (auto splitType : {splittingStrategy::bitMasking,splittingStrategy::roundToNearest}) {
1518
for (auto accumulationType : {accumulationStrategy::floatingPoint, accumulationStrategy::integer}) {
1619
for (size_t numSplitA : { 1, 2, 10 }) {
1720
for (size_t numSplitB : { 1, 2, 10 }) {
1821
for (size_t m = 10; m <= 50; m += 10) {
1922
for (size_t p = 10; p <= 50; p += 10) {
2023
for (size_t n = 10; n <= 50; n += 10) {
21-
std::vector<my_fp_type> A(m * p);
22-
std::vector<my_fp_type> B(p * n);
24+
std::vector<fp_t> A(m * p);
25+
std::vector<fp_t> B(p * n);
2326

2427
// Initalize matrix with random values.
2528
std::default_random_engine generator(std::random_device{}());
26-
std::uniform_real_distribution<double> distribution(-100000.0, 100000.0);
29+
std::uniform_real_distribution<fp_t> distribution(-100000.0, 100000.0);
2730
for (auto & element : A)
2831
element = numSplitA < 10 ? ldexp(1.0, 2 * numSplitA) - 1 : distribution(generator);
2932
for (auto & element : B)
3033
element = numSplitB < 10 ? ldexp(1.0, 2 * numSplitB) - 1 : distribution(generator);
3134

32-
auto C = gemmi<my_fp_type, int8_t, int32_t>(A, B, m, p, n, numSplitA, numSplitB, splitType, accumulationType);
35+
auto C = gemmi<fp_t, int8_t, int32_t>(A, B, m, p, n, numSplitA, numSplitB, splitType, accumulationType);
3336
auto C_ref = reference_gemm(A, B, m, p, n);
3437

35-
double relative_error = frobenius_norm<my_fp_type, double>(C - C_ref) / frobenius_norm<my_fp_type, double>(C);
38+
double relative_error = frobenius_norm<fp_t, double>(C - C_ref) / frobenius_norm<fp_t, double>(C);
3639

37-
REQUIRE(relative_error < 1e-15);
40+
REQUIRE(relative_error < tolerance<fp_t>());
3841
}
3942
}
4043
}
4144
}
4245
}
4346
}
4447
}
48+
}
49+
50+
TEST_CASE("GEMMI accuracy binary64", "[gemmi]") {
51+
runTest<double>();
52+
}
53+
54+
TEST_CASE("GEMMI accuracy binary32", "[gemmi]") {
55+
runTest<float>();
4556
}

0 commit comments

Comments
 (0)