1+ #include < catch2/catch_test_macros.hpp>
12#include < random>
23#include < bit>
3- #include < cassert>
44#include < vector>
55
66#include " gemmi.hpp"
77#include " utilities.hpp"
88
// NOTE(review): this listing is a rendered diff. Lines marked '-' belong to the
// removed main()/assert based test; lines marked '+' and unmarked context lines
// form the Catch2 TEST_CASE that replaces it.
9- int main () {
10-
11- // typedef float my_fp_type;
12- // typedef uint32_t my_int_type;
// Accuracy test: compares the result of gemmi<...>() against reference_gemm()
// for several split counts and matrix sizes, and requires the relative error
// to stay below 1e-15.
9+ TEST_CASE (" GEMMI accuracy" , " [gemmi]" ) {
1310
// Floating-point element type used for all matrices in this test.
1411 typedef double my_fp_type;
1512
// Removed fixed-size 2x2 case: random B, hand-picked values for all four
// entries of A, a single split count of 10.
16- size_t ms = 2 , ns = 2 , ps = 2 ;
17- std::vector<my_fp_type> As (ms * ps);
18- std::vector<my_fp_type> Bs (ps * ns);
19- std::default_random_engine generator (std::random_device{}());
20- std::uniform_real_distribution<double > distribution (-100000.0 , 100000.0 );
21- for (auto & element : As)
22- element = distribution (generator);
23- for (auto & element : Bs)
24- element = distribution (generator);
// The hex patterns below are the single-precision (float) encodings of the
// values, although the elements are stored as my_fp_type (double).
25- As[0 ] = 1.984375 ; // 0x3FFE0000 -> 1.11111 10000 00000 00000 000
26- As[1 ] = 1.999969482421875 ; // 0x3FFFFF00 -> 1.11111 11111 11111 00000 000
27- As[2 ] = 1.99993896484375 ; // 0x3FFFFE00 -> 1.11111 11111 11110 00000 000
28- As[3 ] = 1.9998779296875 ; // 0x3FFFFC00 -> 1.11111 11111 11100 00000 000
29-
30- auto Cs = gemmi<my_fp_type, int8_t , int32_t >(As, Bs, ms, ps, ns, 10 );
31- auto Cs_ref = reference_gemm (As, Bs, ms, ps, ns);
32-
33- double relErr = frobenius_norm<my_fp_type, double >(Cs - Cs_ref) / frobenius_norm<my_fp_type, double >(Cs);
34-
35- std::cout << " Relative error: " << relErr << std::endl;
36- assert (relErr < 1e-15 );
37-
38- // Test different sizes.
13+ // Test different sizes
// Split counts for A and B are each taken from {1, 2, 10}; m runs over
// 10, 20, ..., 50.
3914 for (size_t numSplitA : { 1 , 2 , 10 }) {
4015 for (size_t numSplitB : { 1 , 2 , 10 }) {
4116 for (size_t m = 10 ; m <= 50 ; m += 10 ) {
// NOTE(review): the hunk header below hides two lines of the new file (17-18).
// Judging by the closing braces and the use of p and n further down, these are
// presumably the loops over p and n -- confirm against the full file.
@@ -44,27 +19,23 @@ int main() {
4419 std::vector<my_fp_type> A (m * p);
4520 std::vector<my_fp_type> B (p * n);
4621
47- std::cout << " m: " << m << " , p: " << p << " , n: " << n << std::endl;
48-
// For a split count below 10 every entry is the constant 2^(2 * numSplit) - 1;
// only the split count 10 uses uniformly distributed random entries.
5023 // Initialize matrix with random values.
// NOTE(review): the engine is re-seeded from std::random_device on every
// iteration, so the random inputs differ from run to run.
5124 std::default_random_engine generator (std::random_device{}());
5225 std::uniform_real_distribution<double > distribution (-100000.0 , 100000.0 );
5326 for (auto & element : A)
5427 element = numSplitA < 10 ? ldexp (1.0 , 2 * numSplitA) - 1 : distribution (generator);
5528 for (auto & element : B)
5629 element = numSplitB < 10 ? ldexp (1.0 , 2 * numSplitB) - 1 : distribution (generator);
29+
// Product under test (int8_t / int32_t template arguments, separate split
// counts for A and B) and the reference product of the same operands.
5630 auto C = gemmi<my_fp_type, int8_t , int32_t >(A, B, m, p, n, numSplitA, numSplitB);
5731 auto C_ref = reference_gemm (A, B, m, p, n);
5832
// Relative error; presumably a Frobenius norm, per the helper's name.
// NOTE(review): the denominator is the norm of C (the result under test),
// not of C_ref -- confirm this is intended.
5933 double relative_error = frobenius_norm<my_fp_type, double >(C - C_ref) / frobenius_norm<my_fp_type, double >(C);
6034
61- std::cout << " Relative error: " << relative_error << std::endl;
62- assert (relative_error < 1e-15 );
35+ REQUIRE (relative_error < 1e-15 );
6336 }
6437 }
6538 }
6639 }
6740 }
68-
69- return 0 ;
70- }
41+ }
0 commit comments