22
22
using namespace sycl ;
23
23
24
24
template <typename T>
25
- void test (queue &Q, int M, int N, int K)
25
+ static
26
+ bool test (queue &Q, int M, int N, int K)
26
27
{
27
- std::cout << " \n Benchmarking (" << M << " x " << K << " ) x (" << K << " x " << N << " ) matrix multiplication, " << type_string<T>() << std::endl ;;
28
+ std::cout << " \n Benchmarking (" << M << " x " << K << " ) x (" << K << " x " << N << " ) matrix multiplication, " << type_string<T>() << " \n " ;;
28
29
29
30
std::cout << " -> Initializing data...\n " ;
30
31
@@ -38,7 +39,8 @@ void test(queue &Q, int M, int N, int K)
38
39
auto C = malloc_device<T>(ldc * N, Q);
39
40
40
41
constexpr int rd_size = 1048576 ;
41
- auto host_data = malloc_host<T>(rd_size, Q);
42
+ std::vector<T> host_vector (rd_size);
43
+ auto host_data = host_vector.data ();
42
44
43
45
/* Measure time for a given number of GEMM calls */
44
46
auto time_gemms = [=, &Q](int runs) -> double {
@@ -74,10 +76,9 @@ void test(queue &Q, int M, int N, int K)
74
76
}
75
77
if (linear_id >= elems) break ;
76
78
}
77
- std::cout << (ok ? " passes." : " FAILS!" ) << std::endl;
78
- if (!ok) {
79
- exit (1 );
80
- }
79
+
80
+ std::cout << " gemm " << (ok ? " passes." : " FAILS!" ) << " for type: " << type_string<T>() << " \n " ;
81
+ if (!ok) { return false ; }
81
82
82
83
/* Fill A/B with random data */
83
84
generate_random_data (rd_size, host_data);
@@ -114,15 +115,17 @@ void test(queue &Q, int M, int N, int K)
114
115
unit = ' P' ;
115
116
}
116
117
117
- std::cout << " \n Average performance: " << flops << unit << ' F' << std::endl ;
118
+ std::cout << " \n Average performance: " << flops << unit << ' F' << " \n " ;
118
119
119
120
/* Free data */
120
- free (A, Q);
121
- free (B, Q);
122
121
free (C, Q);
123
- free (host_data, Q);
122
+ free (B, Q);
123
+ free (A, Q);
124
+
125
+ return true ;
124
126
}
125
127
128
+ static
126
129
void usage (const char *pname)
127
130
{
128
131
std::cerr << " Usage:\n "
@@ -133,17 +136,37 @@ void usage(const char *pname)
133
136
<< " double [default]\n "
134
137
<< " single\n "
135
138
<< " half\n "
139
+ << " all (runs all above)\n "
136
140
<< " \n "
137
141
<< " This benchmark uses the default DPC++ device, which can be controlled using\n "
138
142
<< " the ONEAPI_DEVICE_SELECTOR environment variable\n " ;
139
143
std::exit (1 );
140
144
}
141
145
146
+ static
147
+ bool device_has_fp64 (sycl::device const & D) {
148
+ return (D.get_info <sycl::info::device::double_fp_config>().size () != 0 );
149
+ }
150
+
151
+ static
152
+ void device_info (sycl::device const & D) {
153
+ std::cout << " oneMKL DPC++ GEMM benchmark\n "
154
+ << " ---------------------------\n "
155
+ << " Platform: " << D.get_platform ().get_info <info::platform::name>() << " \n "
156
+ << " Device: " << D.get_info <info::device::name>() << " \n "
157
+ << " Driver_version: " << D.get_info <info::device::driver_version>() << " \n "
158
+ << " Core/EU count: " << D.get_info <info::device::max_compute_units>() << " \n "
159
+ << " Maximum clock frequency: " << D.get_info <info::device::max_clock_frequency>() << " MHz" << " \n "
160
+ << " FP64 capability: " << (device_has_fp64 (D) ? " yes" : " no" ) << " \n "
161
+ << " \n "
162
+ ;
163
+ }
164
+
142
165
int main (int argc, char **argv)
143
166
{
144
167
auto pname = argv[0 ];
145
168
int M = 4096 , N = 4096 , K = 4096 ;
146
- std::string type = " double " ;
169
+ std::string type = " none " ;
147
170
148
171
if (argc <= 1 )
149
172
usage (pname);
@@ -163,20 +186,55 @@ int main(int argc, char **argv)
163
186
if (M <= 0 || N <= 0 || K <= 0 )
164
187
usage (pname);
165
188
166
- queue Q;
189
+ bool g_success = true ;
190
+ try {
191
+ device D (default_selector_v);
192
+ device_info (D);
167
193
168
- std::cout << " oneMKL DPC++ GEMM benchmark\n "
169
- << " ---------------------------\n "
170
- << " Device: " << Q.get_device ().get_info <info::device::name>() << std::endl
171
- << " Core/EU count: " << Q.get_device ().get_info <info::device::max_compute_units>() << std::endl
172
- << " Maximum clock frequency: " << Q.get_device ().get_info <info::device::max_clock_frequency>() << " MHz" << std::endl;
173
-
174
- if (type == " double" )
175
- test<double >(Q, M, N, K);
176
- else if (type == " single" || type == " float" )
177
- test<float >(Q, M, N, K);
178
- else if (type == " half" )
179
- test<half>(Q, M, N, K);
180
- else
181
- usage (pname);
194
+ context C (D);
195
+ queue Q (C, D);
196
+
197
+ if (" none" == type)
198
+ std::string type = device_has_fp64 (D) ? " double" : " float" ;
199
+
200
+ if (type == " double" ) {
201
+ if (device_has_fp64 (D))
202
+ test<double >(Q, M, N, K);
203
+ else {
204
+ std::cout << " no FP64 capability on given SYCL device and type == \" double\" " ;
205
+ return 1 ;
206
+ }
207
+ }
208
+ else if (type == " single" || type == " float" )
209
+ g_success = g_success && test<float >(Q, M, N, K);
210
+ else if (type == " half" )
211
+ g_success = g_success && test<half>(Q, M, N, K);
212
+ else if (type == " all" ) {
213
+ type = " half" ;
214
+ g_success = g_success && test<half>(Q, M, N, K);
215
+
216
+ type = " float" ;
217
+ g_success = g_success && test<float >(Q, M, N, K);
218
+
219
+ if (device_has_fp64 (D)) {
220
+ type = " double" ;
221
+ g_success = g_success && test<double >(Q, M, N, K);
222
+ }
223
+ } else {
224
+ type = " none" ;
225
+ usage (pname);
226
+ }
227
+ } catch (sycl::exception const & e) {
228
+ std::cerr << " SYCL exception: " << e.what () << " \n " ;
229
+ std::cerr << " while performing GEMM for"
230
+ << " M=" << M
231
+ << " , N=" << N
232
+ << " , K=" << K
233
+ << " , type `" << type << " `"
234
+ << " \n " ;
235
+ return 139 ;
236
+ }
237
+
238
+ return g_success ? 0 : 1 ;
182
239
}
240
+
0 commit comments