Skip to content

Commit 12edf97

Browse files
committed
[cublas] introduce onemkl_cublas_host_task
[cublas] move dpc++ internal headers into cublas_task.hpp
1 parent e8e3dab commit 12edf97

File tree

8 files changed

+365
-328
lines changed

8 files changed

+365
-328
lines changed

src/blas/backends/cublas/cublas_batch.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
* limitations under the License.
1717
*
1818
**************************************************************************/
19-
#include <CL/sycl/detail/pi.hpp>
2019
#include "cublas_helper.hpp"
21-
#include "cublas_scope_handle.hpp"
20+
#include "cublas_task.hpp"
21+
2222
#include "oneapi/mkl/exceptions.hpp"
2323
#include "oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp"
2424

@@ -42,12 +42,12 @@ inline void gemm_batch(Func func, cl::sycl::queue &queue, transpose transa, tran
4242
auto a_acc = a.template get_access<cl::sycl::access::mode::read>(cgh);
4343
auto b_acc = b.template get_access<cl::sycl::access::mode::read>(cgh);
4444
auto c_acc = c.template get_access<cl::sycl::access::mode::read_write>(cgh);
45-
cgh.interop_task([=](cl::sycl::interop_handler ih) {
46-
auto sc = CublasScopedContextHandler(queue);
45+
onemkl_cublas_host_task(cgh, queue,[=](CublasScopedContextHandler sc) {
4746
auto handle = sc.get_handle(queue);
48-
auto a_ = sc.get_mem<cuDataType *>(ih, a_acc);
49-
auto b_ = sc.get_mem<cuDataType *>(ih, b_acc);
50-
auto c_ = sc.get_mem<cuDataType *>(ih, c_acc);
47+
48+
auto a_ = sc.get_mem<cuDataType *>(a_acc);
49+
auto b_ = sc.get_mem<cuDataType *>(b_acc);
50+
auto c_ = sc.get_mem<cuDataType *>(c_acc);
5151
cublasStatus_t err;
5252
CUBLAS_ERROR_FUNC(func, err, handle, get_cublas_operation(transa),
5353
get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, a_, lda,
@@ -122,9 +122,9 @@ inline cl::sycl::event gemm_batch(Func func, cl::sycl::queue &queue, transpose t
122122
for (int64_t i = 0; i < num_events; i++) {
123123
cgh.depends_on(dependencies[i]);
124124
}
125-
cgh.interop_task([=](cl::sycl::interop_handler ih) {
126-
auto sc = CublasScopedContextHandler(queue);
125+
onemkl_cublas_host_task(cgh, queue,[=](CublasScopedContextHandler sc) {
127126
auto handle = sc.get_handle(queue);
127+
128128
auto a_ = reinterpret_cast<const cuDataType *>(a);
129129
auto b_ = reinterpret_cast<const cuDataType *>(b);
130130
auto c_ = reinterpret_cast<cuDataType *>(c);
@@ -170,9 +170,9 @@ inline cl::sycl::event gemm_batch(Func func, cl::sycl::queue &queue, transpose *
170170
for (int64_t i = 0; i < num_events; i++) {
171171
cgh.depends_on(dependencies[i]);
172172
}
173-
cgh.interop_task([=](cl::sycl::interop_handler ih) {
174-
auto sc = CublasScopedContextHandler(queue);
173+
onemkl_cublas_host_task(cgh, queue,[=](CublasScopedContextHandler sc) {
175174
auto handle = sc.get_handle(queue);
175+
176176
int64_t offset = 0;
177177
cublasStatus_t err;
178178
for (int64_t i = 0; i < group_count; i++) {

src/blas/backends/cublas/cublas_extensions.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
* limitations under the License.
1717
*
1818
**************************************************************************/
19-
#include <CL/sycl/detail/pi.hpp>
2019
#include "cublas_helper.hpp"
21-
#include "cublas_scope_handle.hpp"
20+
#include "cublas_task.hpp"
21+
2222
#include "oneapi/mkl/exceptions.hpp"
2323
#include "oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp"
2424

src/blas/backends/cublas/cublas_level1.cpp

Lines changed: 97 additions & 93 deletions
Large diffs are not rendered by default.

src/blas/backends/cublas/cublas_level2.cpp

Lines changed: 153 additions & 153 deletions
Large diffs are not rendered by default.

src/blas/backends/cublas/cublas_level3.cpp

Lines changed: 65 additions & 66 deletions
Large diffs are not rendered by default.

src/blas/backends/cublas/cublas_scope_handle.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ cublas_handle::~cublas_handle() noexcept(false) {
4848
*/
4949
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};
5050

51-
CublasScopedContextHandler::CublasScopedContextHandler(cl::sycl::queue queue) {
51+
CublasScopedContextHandler::CublasScopedContextHandler(cl::sycl::queue queue, cl::sycl::interop_handler& ih): ih(ih){
5252
placedContext_ = queue.get_context();
5353
auto device = queue.get_device();
5454
auto desired = cl::sycl::get_native<cl::sycl::backend::cuda>(placedContext_);

src/blas/backends/cublas/cublas_scope_handle.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,13 @@ class CublasScopedContextHandler {
6868
CUcontext original_;
6969
cl::sycl::context placedContext_;
7070
bool needToRecover_;
71+
cl::sycl::interop_handler& ih;
7172
static thread_local cublas_handle handle_helper;
7273
CUstream get_stream(const cl::sycl::queue &queue);
7374
cl::sycl::context get_context(const cl::sycl::queue &queue);
7475

7576
public:
76-
CublasScopedContextHandler(cl::sycl::queue queue);
77+
CublasScopedContextHandler(cl::sycl::queue queue, cl::sycl::interop_handler& ih);
7778

7879
~CublasScopedContextHandler() noexcept(false);
7980
/**
@@ -87,7 +88,7 @@ class CublasScopedContextHandler {
8788
// This is a work-around function for reinterpret_casting the memory. This
8889
// will be fixed when SYCL-2020 has been implemented for Pi backend.
8990
template <typename T, typename U>
90-
inline T get_mem(cl::sycl::interop_handler ih, U acc) {
91+
inline T get_mem(U acc) {
9192
CUdeviceptr cudaPtr = ih.get_mem<cl::sycl::backend::cuda>(acc);
9293
return reinterpret_cast<T>(cudaPtr);
9394
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#ifndef _MKL_BLAS_CUBLAS_TASK_HPP_
2+
#define _MKL_BLAS_CUBLAS_TASK_HPP_
3+
#include <cublas_v2.h>
4+
#include <cuda.h>
5+
#include <complex>
6+
#include <CL/sycl.hpp>
7+
#include "oneapi/mkl/types.hpp"
8+
#include "cublas_scope_handle.hpp"
9+
#include <CL/sycl/detail/pi.hpp>
10+
11+
namespace oneapi {
12+
namespace mkl {
13+
namespace blas {
14+
namespace cublas {
15+
16+
template <typename H, typename F>
17+
static inline auto host_task_internal(H &cgh, cl::sycl::queue queue, F f) -> decltype(cgh.interop_task(f)) {
18+
cgh.interop_task([f, queue](cl::sycl::interop_handler ih){
19+
auto sc = CublasScopedContextHandler(queue, ih);
20+
f(sc);
21+
});
22+
}
23+
24+
template <typename H, typename F>
25+
static inline void onemkl_cublas_host_task(H &cgh, cl::sycl::queue queue, F f) {
26+
(void)host_task_internal(cgh, queue, f);
27+
}
28+
29+
} // namespace cublas
30+
} // namespace blas
31+
} // namespace mkl
32+
} // namespace oneapi
33+
#endif // _MKL_BLAS_CUBLAS_TASK_HPP_

0 commit comments

Comments
 (0)