#include <bit>
#include <cmath>
#include <iostream>
#include <random>
#include <vector>

#include "gemmi.hpp"
#include "utilities.hpp"
89
/**
 * Upper bound on the relative error accepted by the accuracy check for
 * element type fp_t.
 *
 * The primary template returns 0. Because the check is a strict
 * `error < tolerance<fp_t>()`, a type without its own specialization can
 * never pass, so an unsupported type fails loudly instead of silently
 * passing.
 *
 * @tparam fp_t Floating-point element type under test.
 * @return The tolerance as a double.
 */
template <typename fp_t>
constexpr double tolerance() {
    return 0;
}

/// Tolerance used when the matrices hold float elements.
template <>
constexpr double tolerance<float>() {
    return 1e-6;
}

/// Tolerance used when the matrices hold double elements.
template <>
constexpr double tolerance<double>() {
    return 1e-15;
}
1213
14+ template <typename fp_t >
15+ void runTest () {
1316 // Test different sizes
14- for (auto splitType : {splittingStrategy::bitMasking, splittingStrategy::roundToNearest}) {
17+ for (auto splitType : {splittingStrategy::bitMasking,splittingStrategy::roundToNearest}) {
1518 for (auto accumulationType : {accumulationStrategy::floatingPoint, accumulationStrategy::integer}) {
1619 for (size_t numSplitA : { 1 , 2 , 10 }) {
1720 for (size_t numSplitB : { 1 , 2 , 10 }) {
1821 for (size_t m = 10 ; m <= 50 ; m += 10 ) {
1922 for (size_t p = 10 ; p <= 50 ; p += 10 ) {
2023 for (size_t n = 10 ; n <= 50 ; n += 10 ) {
21- std::vector<my_fp_type > A (m * p);
22- std::vector<my_fp_type > B (p * n);
24+ std::vector<fp_t > A (m * p);
25+ std::vector<fp_t > B (p * n);
2326
2427 // Initalize matrix with random values.
2528 std::default_random_engine generator (std::random_device{}());
26- std::uniform_real_distribution<double > distribution (-100000.0 , 100000.0 );
29+ std::uniform_real_distribution<fp_t > distribution (-100000.0 , 100000.0 );
2730 for (auto & element : A)
2831 element = numSplitA < 10 ? ldexp (1.0 , 2 * numSplitA) - 1 : distribution (generator);
2932 for (auto & element : B)
3033 element = numSplitB < 10 ? ldexp (1.0 , 2 * numSplitB) - 1 : distribution (generator);
3134
32- auto C = gemmi<my_fp_type , int8_t , int32_t >(A, B, m, p, n, numSplitA, numSplitB, splitType, accumulationType);
35+ auto C = gemmi<fp_t , int8_t , int32_t >(A, B, m, p, n, numSplitA, numSplitB, splitType, accumulationType);
3336 auto C_ref = reference_gemm (A, B, m, p, n);
3437
35- double relative_error = frobenius_norm<my_fp_type , double >(C - C_ref) / frobenius_norm<my_fp_type , double >(C);
38+ double relative_error = frobenius_norm<fp_t , double >(C - C_ref) / frobenius_norm<fp_t , double >(C);
3639
37- REQUIRE (relative_error < 1e-15 );
40+ REQUIRE (relative_error < tolerance< fp_t >() );
3841 }
3942 }
4043 }
4144 }
4245 }
4346 }
4447 }
48+ }
49+
50+ TEST_CASE (" GEMMI accuracy binary64" , " [gemmi]" ) {
51+ runTest<double >();
52+ }
53+
54+ TEST_CASE (" GEMMI accuracy binary32" , " [gemmi]" ) {
55+ runTest<float >();
4556}
0 commit comments