Skip to content

Commit 7f0c812

Browse files
committed
Rough verification for matrix_mul_mkl.cpp
1 parent 631dc54 commit 7f0c812

File tree

2 files changed

+55
-8
lines changed

2 files changed

+55
-8
lines changed

Libraries/oneMKL/matrix_mul_mkl/matrix_mul_mkl.cpp

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,10 @@ void test(queue &Q, int M, int N, int K)
3636
auto A = malloc_device<T>(lda * K, Q);
3737
auto B = malloc_device<T>(ldb * N, Q);
3838
auto C = malloc_device<T>(ldc * N, Q);
39+
auto flag = malloc_shared<int>(1, Q);
3940

40-
/* Fill A/B with random data */
4141
constexpr int rd_size = 1048576;
42-
auto random_data = malloc_host<T>(rd_size, Q);
43-
generate_random_data(rd_size, random_data);
44-
45-
replicate_data(Q, A, lda * K, random_data, rd_size);
46-
replicate_data(Q, B, ldb * N, random_data, rd_size);
42+
auto host_data = malloc_host<T>(rd_size, Q);
4743

4844
/* Measure time for a given number of GEMM calls */
4945
auto time_gemms = [=, &Q](int runs) -> double {
@@ -57,7 +53,49 @@ void test(queue &Q, int M, int N, int K)
5753
return duration<double>(end - start).count();
5854
};
5955

60-
/* Do a warmup call to initialize MKL and ensure kernels are JIT'ed if needed */
56+
/* Fill A/B with all ones to verify correctness */
57+
generate_ones(rd_size, host_data);
58+
replicate_data(Q, A, lda * K, host_data, rd_size);
59+
replicate_data(Q, B, ldb * N, host_data, rd_size);
60+
61+
/* Verify that the leading entries of C are correct */
62+
std::cout << " -> Verification...\n";
63+
(void) time_gemms(1);
64+
size_t elems = std::min(ldc * N, rd_size);
65+
Q.copy(C, host_data, elems);
66+
flag[0] = 0;
67+
int linear_id = 0;
68+
for (size_t j = 0; j < N; j++) {
69+
for (size_t i = 0; i < M; i++) {
70+
linear_id = j*ldc + i;
71+
if (linear_id >= elems) break;
72+
if (host_data[linear_id] != T(K)) {
73+
flag[0] = 1;
74+
}
75+
}
76+
if (linear_id >= elems) break;
77+
}
78+
/*
79+
for (size_t i = 0; i < elems; i++) {
80+
int count = 0;
81+
if (host_data[i] != T(K)) {
82+
flag[0] = 1;
83+
if (count < 10) {
84+
sycl::ext::oneapi::experimental::printf("error elem %d expect %f got %f\n",
85+
i, T(K), host_data[i]);
86+
count++;
87+
}
88+
}
89+
}
90+
*/
91+
std::cout << " verification " << (flag[0] == 0 ? "passes." : "FAILS!") << std::endl;
92+
93+
/* Fill A/B with random data */
94+
generate_random_data(rd_size, host_data);
95+
replicate_data(Q, A, lda * K, host_data, rd_size);
96+
replicate_data(Q, B, ldb * N, host_data, rd_size);
97+
98+
/* Do a warmup call with random data to initialize MKL and ensure kernels are JIT'ed if needed */
6199
std::cout << " -> Warmup...\n";
62100
(void) time_gemms(1);
63101

@@ -93,7 +131,8 @@ void test(queue &Q, int M, int N, int K)
93131
free(A, Q);
94132
free(B, Q);
95133
free(C, Q);
96-
free(random_data, Q);
134+
free(flag, Q);
135+
free(host_data, Q);
97136
}
98137

99138
void usage(const char *pname)

Libraries/oneMKL/matrix_mul_mkl/utilities.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ int nice_ld(int x)
2424
return x;
2525
}
2626

27+
template <typename T>
28+
void generate_ones(size_t elems, T *v)
29+
{
30+
#pragma omp parallel for
31+
for (size_t i = 0; i < elems; i++)
32+
v[i] = T(1);
33+
}
34+
2735
/* Random number generation helpers */
2836
template <typename T>
2937
void generate_random_data(size_t elems, T *v)

0 commit comments

Comments
 (0)