Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/batch/axpy_batch_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during AXPY_BATCH:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -186,7 +186,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during AXPY_BATCH:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/batch/gemm_batch_stride.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -181,7 +181,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -208,7 +208,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) {
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during GEMM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/batch/gemm_batch_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during GEMM_BATCH:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -263,7 +263,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during GEMM_BATCH:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/batch/trsm_batch_stride.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ int test(device *dev, oneapi::mkl::layout layout) {
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during TRSM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -173,7 +173,7 @@ int test(device *dev, oneapi::mkl::layout layout) {
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during TRSM_BATCH_STRIDE:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/extensions/gemm_bias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa,
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during GEMM_BIAS:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -142,7 +142,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa,
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during GEMM_BIAS:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/extensions/gemmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during GEMMT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -121,7 +121,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during GEMMT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/extensions/gemmt_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during GEMMT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -123,7 +123,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower,
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during GEMMT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
12 changes: 4 additions & 8 deletions tests/unit_tests/blas/include/reference_blas_templates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,10 @@ template <typename fp>
static void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
const int *n, const int *k, const fp *alpha, const fp *a, const int *lda,
const fp *b, const int *ldb, const fp *beta, fp *c, const int *ldc);

template <>
void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
const int *n, const int *k, const half *alpha, const half *a, const int *lda,
const half *b, const int *ldb, const half *beta, half *c, const int *ldc) {
const int *n, const int *k, const cl::sycl::half *alpha, const cl::sycl::half *a, const int *lda,
const cl::sycl::half *b, const int *ldb, const cl::sycl::half *beta, cl::sycl::half *c, const int *ldc) {
// Not supported in NETLIB. SGEMM is used as reference.
int sizea, sizeb, sizec;
const float alphaf = *alpha;
Expand All @@ -255,7 +254,6 @@ void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, c
oneapi::mkl::aligned_free(bf);
oneapi::mkl::aligned_free(cf);
}

template <>
void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
const int *n, const int *k, const float *alpha, const float *a, const int *lda,
Expand Down Expand Up @@ -291,11 +289,10 @@ template <typename fpa, typename fpc>
static void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
const int *n, const int *k, const fpc *alpha, const fpa *a, const int *lda,
const fpa *b, const int *ldb, const fpc *beta, fpc *c, const int *ldc);

template <>
void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, const int *m,
const int *n, const int *k, const float *alpha, const half *a, const int *lda,
const half *b, const int *ldb, const float *beta, float *c, const int *ldc) {
const int *n, const int *k, const float *alpha, const cl::sycl::half *a, const int *lda,
const cl::sycl::half *b, const int *ldb, const float *beta, float *c, const int *ldc) {
// Not supported in NETLIB. SGEMM is used as reference.
int sizea, sizeb;
if (layout == CblasColMajor) {
Expand All @@ -314,7 +311,6 @@ void gemm(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transa, CBLAS_TRANSPOSE transb, c
oneapi::mkl::aligned_free(af);
oneapi::mkl::aligned_free(bf);
}

template <typename fp>
static void symm(CBLAS_LAYOUT layout, CBLAS_SIDE left_right, CBLAS_UPLO uplo, const int *m,
const int *n, const fp *alpha, const fp *a, const int *lda, const fp *b,
Expand Down
11 changes: 7 additions & 4 deletions tests/unit_tests/blas/include/test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@
#include <type_traits>

#include <CL/sycl.hpp>
#include "oneapi/mkl/detail/config.hpp"

namespace std {
//#ifdef ENABLE_HALF_ROUTINES
static cl::sycl::half abs(cl::sycl::half v) {
if (v < cl::sycl::half(0))
return -v;
else
return v;
}
//#endif
} // namespace std

// Complex helpers.
Expand Down Expand Up @@ -140,12 +143,12 @@ template <>
uint8_t rand_scalar() {
return std::rand() % 128;
}

//#ifdef ENABLE_HALF_ROUTINES
template <>
half rand_scalar() {
return half(std::rand() % 32000) / half(32000) - half(0.5);
cl::sycl::half rand_scalar() {
return cl::sycl::half(std::rand() % 32000) / cl::sycl::half(32000) - cl::sycl::half(0.5);
}

//#endif
template <typename fp>
static fp rand_scalar(int mag) {
fp tmp = fp(mag) + fp(std::rand()) / fp(RAND_MAX) - fp(0.5);
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/level1/asum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) {
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during ASUM:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -103,7 +103,7 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) {
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during ASUM:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/level1/asum_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) {
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during ASUM:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -111,7 +111,7 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) {
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during ASUM:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/level1/axpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during AXPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -108,7 +108,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during AXPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/level1/axpy_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during AXPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -112,7 +112,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during AXPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/level1/copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) {
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during COPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -106,7 +106,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) {
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during COPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/level1/copy_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) {
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during COPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -111,7 +111,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) {
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during COPY:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/level1/dot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) {
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during DOT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -107,7 +107,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) {
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during DOT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/level1/dot_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) {
catch (exception const& e) {
std::cout << "Caught asynchronous SYCL exception during DOT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -111,7 +111,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) {
catch (exception const& e) {
std::cout << "Caught synchronous SYCL exception during DOT:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented& e) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/blas/level1/dotc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) {
catch (exception const &e) {
std::cout << "Caught asynchronous SYCL exception during DOTC:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}
}
};
Expand Down Expand Up @@ -109,7 +109,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) {
catch (exception const &e) {
std::cout << "Caught synchronous SYCL exception during DOTC:\n"
<< e.what() << std::endl
<< "OpenCL status: " << e.get_cl_code() << std::endl;
<< "OpenCL status: " << e.what() << std::endl;
}

catch (const oneapi::mkl::unimplemented &e) {
Expand Down
Loading