@@ -37,13 +37,8 @@ void test(queue &Q, int M, int N, int K)
37
37
auto B = malloc_device<T>(ldb * N, Q);
38
38
auto C = malloc_device<T>(ldc * N, Q);
39
39
40
- /* Fill A/B with random data */
41
40
constexpr int rd_size = 1048576 ;
42
- auto random_data = malloc_host<T>(rd_size, Q);
43
- generate_random_data (rd_size, random_data);
44
-
45
- replicate_data (Q, A, lda * K, random_data, rd_size);
46
- replicate_data (Q, B, ldb * N, random_data, rd_size);
41
+ auto host_data = malloc_host<T>(rd_size, Q);
47
42
48
43
/* Measure time for a given number of GEMM calls */
49
44
auto time_gemms = [=, &Q](int runs) -> double {
@@ -57,7 +52,36 @@ void test(queue &Q, int M, int N, int K)
57
52
return duration<double >(end - start).count ();
58
53
};
59
54
60
- /* Do a warmup call to initialize MKL and ensure kernels are JIT'ed if needed */
55
+ /* Fill A/B with all ones to verify correctness */
56
+ generate_ones (rd_size, host_data);
57
+ replicate_data (Q, A, lda * K, host_data, rd_size);
58
+ replicate_data (Q, B, ldb * N, host_data, rd_size);
59
+
60
+ /* Verify that the leading entries of C are correct */
61
+ std::cout << " -> Verification..." ;
62
+ (void ) time_gemms (1 );
63
+ size_t elems = std::min (ldc * N, rd_size);
64
+ Q.copy (C, host_data, elems).wait ();
65
+ bool ok = true ;
66
+ int linear_id = 0 ;
67
+ for (size_t j = 0 ; j < N; j++) {
68
+ for (size_t i = 0 ; i < M; i++) {
69
+ linear_id = j*ldc + i;
70
+ if (linear_id >= elems) break ;
71
+ if (host_data[linear_id] != T (K)) {
72
+ ok = false ;
73
+ }
74
+ }
75
+ if (linear_id >= elems) break ;
76
+ }
77
+ std::cout << (ok ? " passes." : " FAILS!" ) << std::endl;
78
+
79
+ /* Fill A/B with random data */
80
+ generate_random_data (rd_size, host_data);
81
+ replicate_data (Q, A, lda * K, host_data, rd_size);
82
+ replicate_data (Q, B, ldb * N, host_data, rd_size);
83
+
84
+ /* Do a warmup call with random data to initialize MKL and ensure kernels are JIT'ed if needed */
61
85
std::cout << " -> Warmup...\n " ;
62
86
(void ) time_gemms (1 );
63
87
@@ -93,7 +117,7 @@ void test(queue &Q, int M, int N, int K)
93
117
free (A, Q);
94
118
free (B, Q);
95
119
free (C, Q);
96
- free (random_data , Q);
120
+ free (host_data , Q);
97
121
}
98
122
99
123
void usage (const char *pname)
0 commit comments