Skip to content

Commit 9d9dd83

Browse files
committed
[tests][modernize]Make the tests compile with hipSYCL
*Changed e.get_cl_code to e.what *Fixed namespace sycl:: to cl::sycl *Gave explict type to exception_handler
1 parent 1ed12c7 commit 9d9dd83

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

114 files changed

+348
-345
lines changed

tests/unit_tests/blas/batch/axpy_batch_usm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ namespace {
4646
template <typename fp>
4747
int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
4848
// Catch asynchronous exceptions.
49-
auto exception_handler = [](exception_list exceptions) {
49+
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
5050
for (std::exception_ptr const &e : exceptions) {
5151
try {
5252
std::rethrow_exception(e);
5353
}
5454
catch (exception const &e) {
5555
std::cout << "Caught asynchronous SYCL exception during AXPY_BATCH:\n"
5656
<< e.what() << std::endl
57-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
57+
<< "OpenCL status: " << e.what() << std::endl;
5858
}
5959
}
6060
};
@@ -186,7 +186,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
186186
catch (exception const &e) {
187187
std::cout << "Caught synchronous SYCL exception during AXPY_BATCH:\n"
188188
<< e.what() << std::endl
189-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
189+
<< "OpenCL status: " << e.what() << std::endl;
190190
}
191191

192192
catch (const oneapi::mkl::unimplemented &e) {

tests/unit_tests/blas/batch/gemm_batch_stride.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
128128
// Call DPC++ GEMM_BATCH_STRIDE.
129129

130130
// Catch asynchronous exceptions.
131-
auto exception_handler = [](exception_list exceptions) {
131+
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
132132
for (std::exception_ptr const &e : exceptions) {
133133
try {
134134
std::rethrow_exception(e);
135135
}
136136
catch (exception const &e) {
137137
std::cout << "Caught asynchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
138138
<< e.what() << std::endl
139-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
139+
<< "OpenCL status: " << e.what() << std::endl;
140140
}
141141
}
142142
};
@@ -181,7 +181,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
181181
catch (exception const &e) {
182182
std::cout << "Caught synchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
183183
<< e.what() << std::endl
184-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
184+
<< "OpenCL status: " << e.what() << std::endl;
185185
}
186186

187187
catch (const oneapi::mkl::unimplemented &e) {

tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ namespace {
4646
template <typename fp>
4747
int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
4848
// Catch asynchronous exceptions.
49-
auto exception_handler = [](exception_list exceptions) {
49+
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
5050
for (std::exception_ptr const &e : exceptions) {
5151
try {
5252
std::rethrow_exception(e);
5353
}
5454
catch (exception const &e) {
5555
std::cout << "Caught asynchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
5656
<< e.what() << std::endl
57-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
57+
<< "OpenCL status: " << e.what() << std::endl;
5858
}
5959
}
6060
};
@@ -208,7 +208,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
208208
catch (exception const &e) {
209209
std::cout << "Caught synchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
210210
<< e.what() << std::endl
211-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
211+
<< "OpenCL status: " << e.what() << std::endl;
212212
}
213213

214214
catch (const oneapi::mkl::unimplemented &e) {

tests/unit_tests/blas/batch/gemm_batch_usm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ namespace {
4646
template <typename fp>
4747
int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
4848
// Catch asynchronous exceptions.
49-
auto exception_handler = [](exception_list exceptions) {
49+
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
5050
for (std::exception_ptr const &e : exceptions) {
5151
try {
5252
std::rethrow_exception(e);
5353
}
5454
catch (exception const &e) {
5555
std::cout << "Caught asynchronous SYCL exception during GEMM_BATCH:\n"
5656
<< e.what() << std::endl
57-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
57+
<< "OpenCL status: " << e.what() << std::endl;
5858
}
5959
}
6060
};
@@ -263,7 +263,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
263263
catch (exception const &e) {
264264
std::cout << "Caught synchronous SYCL exception during GEMM_BATCH:\n"
265265
<< e.what() << std::endl
266-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
266+
<< "OpenCL status: " << e.what() << std::endl;
267267
}
268268

269269
catch (const oneapi::mkl::unimplemented &e) {

tests/unit_tests/blas/batch/trsm_batch_stride.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,15 @@ int test(device *dev, oneapi::mkl::layout layout) {
121121
// Call DPC++ TRSM_BATCH_STRIDE.
122122

123123
// Catch asynchronous exceptions.
124-
auto exception_handler = [](exception_list exceptions) {
124+
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
125125
for (std::exception_ptr const &e : exceptions) {
126126
try {
127127
std::rethrow_exception(e);
128128
}
129129
catch (exception const &e) {
130130
std::cout << "Caught asynchronous SYCL exception during TRSM_BATCH_STRIDE:\n"
131131
<< e.what() << std::endl
132-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
132+
<< "OpenCL status: " << e.what() << std::endl;
133133
}
134134
}
135135
};
@@ -173,7 +173,7 @@ int test(device *dev, oneapi::mkl::layout layout) {
173173
catch (exception const &e) {
174174
std::cout << "Caught synchronous SYCL exception during TRSM_BATCH_STRIDE:\n"
175175
<< e.what() << std::endl
176-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
176+
<< "OpenCL status: " << e.what() << std::endl;
177177
}
178178

179179
catch (const oneapi::mkl::unimplemented &e) {

tests/unit_tests/blas/extensions/gemm_bias.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa,
8888
// Call DPC++ GEMM_BIAS.
8989

9090
// Catch asynchronous exceptions.
91-
auto exception_handler = [](exception_list exceptions) {
91+
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
9292
for (std::exception_ptr const& e : exceptions) {
9393
try {
9494
std::rethrow_exception(e);
9595
}
9696
catch (exception const& e) {
9797
std::cout << "Caught asynchronous SYCL exception during GEMM_BIAS:\n"
9898
<< e.what() << std::endl
99-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
99+
<< "OpenCL status: " << e.what() << std::endl;
100100
}
101101
}
102102
};
@@ -142,7 +142,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa,
142142
catch (exception const& e) {
143143
std::cout << "Caught synchronous SYCL exception during GEMM_BIAS:\n"
144144
<< e.what() << std::endl
145-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
145+
<< "OpenCL status: " << e.what() << std::endl;
146146
}
147147

148148
catch (const oneapi::mkl::unimplemented& e) {

tests/unit_tests/blas/extensions/gemmt.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
6868
// Call DPC++ GEMMT.
6969

7070
// Catch asynchronous exceptions
71-
auto exception_handler = [](exception_list exceptions) {
71+
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
7272
for (std::exception_ptr const& e : exceptions) {
7373
try {
7474
std::rethrow_exception(e);
7575
}
7676
catch (exception const& e) {
7777
std::cout << "Caught asynchronous SYCL exception during GEMMT:\n"
7878
<< e.what() << std::endl
79-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
79+
<< "OpenCL status: " << e.what() << std::endl;
8080
}
8181
}
8282
};
@@ -121,7 +121,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
121121
catch (exception const& e) {
122122
std::cout << "Caught synchronous SYCL exception during GEMMT:\n"
123123
<< e.what() << std::endl
124-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
124+
<< "OpenCL status: " << e.what() << std::endl;
125125
}
126126

127127
catch (const oneapi::mkl::unimplemented& e) {

tests/unit_tests/blas/extensions/gemmt_usm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
4747
oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, int n, int k, int lda,
4848
int ldb, int ldc, fp alpha, fp beta) {
4949
// Catch asynchronous exceptions.
50-
auto exception_handler = [](exception_list exceptions) {
50+
cl::sycl::async_handler exception_handler = [](exception_list exceptions) {
5151
for (std::exception_ptr const& e : exceptions) {
5252
try {
5353
std::rethrow_exception(e);
5454
}
5555
catch (exception const& e) {
5656
std::cout << "Caught asynchronous SYCL exception during GEMMT:\n"
5757
<< e.what() << std::endl
58-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
58+
<< "OpenCL status: " << e.what() << std::endl;
5959
}
6060
}
6161
};
@@ -123,7 +123,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
123123
catch (exception const& e) {
124124
std::cout << "Caught synchronous SYCL exception during GEMMT:\n"
125125
<< e.what() << std::endl
126-
<< "OpenCL status: " << e.get_cl_code() << std::endl;
126+
<< "OpenCL status: " << e.what() << std::endl;
127127
}
128128

129129
catch (const oneapi::mkl::unimplemented& e) {

tests/unit_tests/blas/include/reference_blas_templates.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ template <typename fp>
224224
static void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
225225
const int *n, const int *k, const fp *alpha, const fp *a, const int *lda,
226226
const fp *b, const int *ldb, const fp *beta, fp *c, const int *ldc);
227-
227+
#ifdef NOT_HIPSYCL
228228
template <>
229229
void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
230230
const int *n, const int *k, const half *alpha, const half *a, const int *lda,
@@ -255,7 +255,7 @@ void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, c
255255
oneapi::mkl::aligned_free(bf);
256256
oneapi::mkl::aligned_free(cf);
257257
}
258-
258+
#endif
259259
template <>
260260
void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
261261
const int *n, const int *k, const float *alpha, const float *a, const int *lda,
@@ -291,7 +291,7 @@ template <typename fpa, typename fpc>
291291
static void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
292292
const int *n, const int *k, const fpc *alpha, const fpa *a, const int *lda,
293293
const fpa *b, const int *ldb, const fpc *beta, fpc *c, const int *ldc);
294-
294+
#ifdef NOT_HIPSYCL
295295
template <>
296296
void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
297297
const int *n, const int *k, const float *alpha, const half *a, const int *lda,
@@ -314,7 +314,7 @@ void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, c
314314
oneapi::mkl::aligned_free(af);
315315
oneapi::mkl::aligned_free(bf);
316316
}
317-
317+
#endif
318318
template <typename fp>
319319
static void symm(CBLAS_LAYOUT layout, CBLAS_SIDE left_right, CBLAS_UPLO uplo, const int *m,
320320
const int *n, const fp *alpha, const fp *a, const int *lda, const fp *b,

tests/unit_tests/blas/include/test_common.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@
2929
#include <CL/sycl.hpp>
3030

3131
namespace std {
32+
#ifdef NOT_HIPSYCL
3233
static cl::sycl::half abs(cl::sycl::half v) {
3334
if (v < cl::sycl::half(0))
3435
return -v;
3536
else
3637
return v;
3738
}
39+
#endif
3840
} // namespace std
3941

4042
// Complex helpers.
@@ -140,12 +142,12 @@ template <>
140142
uint8_t rand_scalar() {
141143
return std::rand() % 128;
142144
}
143-
145+
#ifdef NOT_HIPSYCL
144146
template <>
145147
half rand_scalar() {
146148
return half(std::rand() % 32000) / half(32000) - half(0.5);
147149
}
148-
150+
#endif
149151
template <typename fp>
150152
static fp rand_scalar(int mag) {
151153
fp tmp = fp(mag) + fp(std::rand()) / fp(RAND_MAX) - fp(0.5);

0 commit comments

Comments
 (0)