Skip to content

Commit 2622bda

Browse files
authored
Merge pull request #2571 from zettai-reido/development
[oneMKL] matrix_mul_mkl update
2 parents 4dabbe1 + 9620f10 commit 2622bda

File tree

1 file changed

+85
-27
lines changed

1 file changed

+85
-27
lines changed

Libraries/oneMKL/matrix_mul_mkl/matrix_mul_mkl.cpp

Lines changed: 85 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
using namespace sycl;
2323

2424
template <typename T>
25-
void test(queue &Q, int M, int N, int K)
25+
static
26+
bool test(queue &Q, int M, int N, int K)
2627
{
27-
std::cout << "\nBenchmarking (" << M << " x " << K << ") x (" << K << " x " << N << ") matrix multiplication, " << type_string<T>() << std::endl;;
28+
std::cout << "\nBenchmarking (" << M << " x " << K << ") x (" << K << " x " << N << ") matrix multiplication, " << type_string<T>() << "\n";;
2829

2930
std::cout << " -> Initializing data...\n";
3031

@@ -38,7 +39,8 @@ void test(queue &Q, int M, int N, int K)
3839
auto C = malloc_device<T>(ldc * N, Q);
3940

4041
constexpr int rd_size = 1048576;
41-
auto host_data = malloc_host<T>(rd_size, Q);
42+
std::vector<T> host_vector(rd_size);
43+
auto host_data = host_vector.data();
4244

4345
/* Measure time for a given number of GEMM calls */
4446
auto time_gemms = [=, &Q](int runs) -> double {
@@ -74,10 +76,9 @@ void test(queue &Q, int M, int N, int K)
7476
}
7577
if (linear_id >= elems) break;
7678
}
77-
std::cout << (ok ? " passes." : " FAILS!") << std::endl;
78-
if (!ok) {
79-
exit(1);
80-
}
79+
80+
std::cout << "gemm " << (ok ? " passes." : " FAILS!") << " for type: " << type_string<T>() << "\n";
81+
if (!ok) { return false; }
8182

8283
/* Fill A/B with random data */
8384
generate_random_data(rd_size, host_data);
@@ -114,15 +115,17 @@ void test(queue &Q, int M, int N, int K)
114115
unit = 'P';
115116
}
116117

117-
std::cout << "\nAverage performance: " << flops << unit << 'F' << std::endl;
118+
std::cout << "\nAverage performance: " << flops << unit << 'F' << "\n";
118119

119120
/* Free data */
120-
free(A, Q);
121-
free(B, Q);
122121
free(C, Q);
123-
free(host_data, Q);
122+
free(B, Q);
123+
free(A, Q);
124+
125+
return true;
124126
}
125127

128+
static
126129
void usage(const char *pname)
127130
{
128131
std::cerr << "Usage:\n"
@@ -133,17 +136,37 @@ void usage(const char *pname)
133136
<< " double [default]\n"
134137
<< " single\n"
135138
<< " half\n"
139+
<< " all (runs all above)\n"
136140
<< "\n"
137141
<< "This benchmark uses the default DPC++ device, which can be controlled using\n"
138142
<< " the ONEAPI_DEVICE_SELECTOR environment variable\n";
139143
std::exit(1);
140144
}
141145

146+
static
147+
bool device_has_fp64(sycl::device const& D) {
148+
return (D.get_info<sycl::info::device::double_fp_config>().size() != 0);
149+
}
150+
151+
static
152+
void device_info(sycl::device const& D) {
153+
std::cout << "oneMKL DPC++ GEMM benchmark\n"
154+
<< "---------------------------\n"
155+
<< "Platform: " << D.get_platform().get_info<info::platform::name>() << "\n"
156+
<< "Device: " << D.get_info<info::device::name>() << "\n"
157+
<< "Driver_version: " << D.get_info<info::device::driver_version>() << "\n"
158+
<< "Core/EU count: " << D.get_info<info::device::max_compute_units>() << "\n"
159+
<< "Maximum clock frequency: " << D.get_info<info::device::max_clock_frequency>() << " MHz" << "\n"
160+
<< "FP64 capability: " << (device_has_fp64(D) ? "yes" : "no") << "\n"
161+
<< "\n"
162+
;
163+
}
164+
142165
int main(int argc, char **argv)
143166
{
144167
auto pname = argv[0];
145168
int M = 4096, N = 4096, K = 4096;
146-
std::string type = "double";
169+
std::string type = "none";
147170

148171
if (argc <= 1)
149172
usage(pname);
@@ -163,20 +186,55 @@ int main(int argc, char **argv)
163186
if (M <= 0 || N <= 0 || K <= 0)
164187
usage(pname);
165188

166-
queue Q;
189+
bool g_success = true;
190+
try {
191+
device D(default_selector_v);
192+
device_info(D);
167193

168-
std::cout << "oneMKL DPC++ GEMM benchmark\n"
169-
<< "---------------------------\n"
170-
<< "Device: " << Q.get_device().get_info<info::device::name>() << std::endl
171-
<< "Core/EU count: " << Q.get_device().get_info<info::device::max_compute_units>() << std::endl
172-
<< "Maximum clock frequency: " << Q.get_device().get_info<info::device::max_clock_frequency>() << " MHz" << std::endl;
173-
174-
if (type == "double")
175-
test<double>(Q, M, N, K);
176-
else if (type == "single" || type == "float")
177-
test<float>(Q, M, N, K);
178-
else if (type == "half")
179-
test<half>(Q, M, N, K);
180-
else
181-
usage(pname);
194+
context C(D);
195+
queue Q(C, D);
196+
197+
if ("none" == type)
198+
std::string type = device_has_fp64(D) ? "double" : "float";
199+
200+
if (type == "double") {
201+
if (device_has_fp64(D))
202+
test<double>(Q, M, N, K);
203+
else {
204+
std::cout << "no FP64 capability on given SYCL device and type == \"double\"";
205+
return 1;
206+
}
207+
}
208+
else if (type == "single" || type == "float")
209+
g_success = g_success && test<float>(Q, M, N, K);
210+
else if (type == "half")
211+
g_success = g_success && test<half>(Q, M, N, K);
212+
else if (type == "all") {
213+
type = "half";
214+
g_success = g_success && test<half>(Q, M, N, K);
215+
216+
type = "float";
217+
g_success = g_success && test<float>(Q, M, N, K);
218+
219+
if (device_has_fp64(D)) {
220+
type = "double";
221+
g_success = g_success && test<double>(Q, M, N, K);
222+
}
223+
} else {
224+
type = "none";
225+
usage(pname);
226+
}
227+
} catch (sycl::exception const& e) {
228+
std::cerr << "SYCL exception: " << e.what() << "\n";
229+
std::cerr << " while performing GEMM for"
230+
<< " M=" << M
231+
<< ", N=" << N
232+
<< ", K=" << K
233+
<< ", type `" << type << "`"
234+
<< "\n";
235+
return 139;
236+
}
237+
238+
return g_success ? 0 : 1;
182239
}
240+

0 commit comments

Comments
 (0)