Skip to content

Commit ea48e66

Browse files
committed
begin refactoring
1 parent 009c4dc commit ea48e66

File tree

10 files changed

+582
-27
lines changed

10 files changed

+582
-27
lines changed

src/CMakeLists.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
add_subdirectory(config)
2-
add_subdirectory(core)
2+
# add_subdirectory(core)
33

44
function(add_bkma_executable name)
55
add_executable(${name} ${name}.cpp)
66

77
target_link_libraries(${name}
88
PUBLIC
9-
bkma::config
10-
bkma::core)
9+
bkma::config)
10+
# bkma::core)
1111

1212
target_include_directories(${name}
1313
PUBLIC
1414
${CMAKE_SOURCE_DIR}/src/tools
15-
${CMAKE_SOURCE_DIR}/src/core
15+
${CMAKE_SOURCE_DIR}/src/core/bkma
1616
${CMAKE_SOURCE_DIR}/src/solvers
1717
${CMAKE_SOURCE_DIR}/src
1818
)
@@ -27,5 +27,5 @@ function(add_bkma_executable name)
2727
endfunction()
2828

2929
# Add executables
30-
add_bkma_executable(advection)
30+
# add_bkma_executable(advection)
3131
add_bkma_executable(conv1d)

src/conv1d.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ main(int argc, char **argv) {
104104
const auto channel_out = params.channel_out;
105105
const auto length = params.length;
106106

107-
const auto n0 = params.n0; // n
107+
const auto n0 = params.n0; // n
108108
const auto n1 = params.n1; // l*oc
109-
const auto n2 = params.n2; // n
109+
const auto n2 = params.n2; // n
110110
const auto k = params.k;
111111

112112
span3d_t data(sycl::malloc_shared<real_t>(n0 * n1 * n2, Q), n0,
@@ -164,10 +164,13 @@ main(int argc, char **argv) {
164164

165165
/* Warmup to JIT model */
166166
for (int i = 0; i < 3; ++i)
167-
bkma_run(Q, warmup_data, solver, optim_params).wait();
167+
bkma_run<ConvSolver, BkmaImpl::AdaptiveWg>(Q, warmup_data, solver,
168+
optim_params)
169+
.wait();
168170

169171
auto start = std::chrono::high_resolution_clock::now();
170-
bkma_run(Q, data, solver, optim_params).wait();
172+
bkma_run<ConvSolver, BkmaImpl::AdaptiveWg>(Q, data, solver, optim_params)
173+
.wait();
171174
auto end = std::chrono::high_resolution_clock::now();
172175
const std::chrono::duration<double> elapsed_seconds = end - start;
173176

src/core/bkma/AdaptiveWg.hpp

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#pragma once
2+
#include <bkma_tools.hpp>
3+
4+
// //==============================================================================
5+
// class AdaptiveWg : public IAdvectorX {
6+
// protected:
7+
// using IAdvectorX::IAdvectorX;
8+
9+
// /* We should be able to query max_batchs from the API.
10+
// | x | y/z |
11+
// CUDA:| 2**31-1 | 2**16-1 |
12+
// HIP :| 2**32-1 | 2**32-1 |
13+
// L0 :| 2**32-1 | 2**32-1 | (compile with -fno-sycl-query-fit-in-int)
14+
// CPU : a lot */
15+
// const size_t max_batchs_x_ = 65536 - 1;
16+
// const size_t max_batchs_yz_ = 65536 - 1;
17+
18+
// BatchConfig1D dispatch_dim0_;
19+
// BatchConfig1D dispatch_dim2_;
20+
// WorkItemDispatch local_size_;
21+
// WorkGroupDispatch wg_dispatch_;
22+
23+
// public:
24+
// sycl::event operator()(sycl::queue &Q, real_t *fdist_dev,
25+
// const AdvectionSolver &solver) override;
26+
27+
// AdaptiveWg() = delete;
28+
29+
// AdaptiveWg(const AdvectionSolver &solver, sycl::queue q) {
30+
// const auto n0 = solver.params.n0;
31+
// const auto n1 = solver.params.n1;
32+
// const auto n2 = solver.params.n2;
33+
34+
// dispatch_dim0_ = init_1d_blocking(n0, max_batchs_x_);
35+
// dispatch_dim2_ = init_1d_blocking(n2, max_batchs_yz_);
36+
37+
// // SYCL query returns the size in bytes
38+
// auto max_elem_local_mem =
39+
// q.get_device().get_info<sycl::info::device::local_mem_size>() /
40+
// sizeof(real_t);
41+
42+
// local_size_.set_ideal_sizes(solver.params.pref_wg_size, n0, n1, n2);
43+
// local_size_.adjust_sizes_mem_limit(max_elem_local_mem, n1);
44+
45+
// wg_dispatch_.s0_ = solver.params.seq_size0;
46+
// wg_dispatch_.s2_ = solver.params.seq_size2;
47+
48+
// // TODO: this line is overridden inside the kernel!!! useless
49+
// // wg_dispatch_.set_num_work_groups(n0, n2, dispatch_dim0_.n_batch_,
50+
// // dispatch_dim2_.n_batch_,
51+
// // local_size_.w0_, local_size_.w2_);
52+
// }
53+
// };
54+
55+
// ==========================================
56+
// ==========================================
57+
template <MemorySpace MemType, class MySolver, BkmaImpl Impl>
58+
inline std::enable_if_t<Impl == BkmaImpl::AdaptiveWg, sycl::event>
59+
submit_kernels(sycl::queue &Q, span3d_t data, const MySolver &solver,
60+
const size_t b0_size, const size_t b0_offset,
61+
const size_t b2_size, const size_t b2_offset,
62+
const size_t orig_w0, const size_t w1, const size_t orig_w2,
63+
WorkGroupDispatch wg_dispatch,
64+
span3d_t global_scratch = span3d_t{}) {
65+
66+
const auto w0 = sycl::min(orig_w0, b0_size);
67+
const auto w2 = sycl::min(orig_w2, b2_size);
68+
69+
wg_dispatch.set_num_work_groups(b0_size, b2_size, 1, 1, w0, w2);
70+
auto const seq_size0 = wg_dispatch.s0_;
71+
auto const seq_size2 = wg_dispatch.s2_;
72+
auto const g0 = wg_dispatch.g0_;
73+
auto const g2 = wg_dispatch.g2_;
74+
75+
const sycl::range<3> global_size(g0 * w0, w1, g2 * w2);
76+
const sycl::range<3> local_size(w0, w1, w2);
77+
78+
auto n0 = data.extent(0);
79+
auto n1 = data.extent(1);
80+
auto n2 = data.extent(2);
81+
82+
const auto window = solver.window();
83+
const auto nw = n1 - (window-1);
84+
85+
return Q.submit([&](sycl::handler &cgh) {
86+
auto mallocator = [&]() {
87+
if constexpr (MemType == MemorySpace::Local) {
88+
sycl::range<3> acc_range(w0, w2, nw);
89+
return MemAllocator<MemType>(acc_range, cgh);
90+
} else {
91+
extents_t ext(b0_size, n2, n1);
92+
return MemAllocator<MemType>(global_scratch);
93+
}
94+
}();
95+
96+
cgh.parallel_for(
97+
sycl::nd_range<3>{global_size, local_size},
98+
[=](auto itm) {
99+
span3d_t scr(mallocator.get_pointer(),
100+
mallocator.get_extents());
101+
102+
const auto i1 = itm.get_local_id(1);
103+
const auto local_i0 = compute_index<MemType>(itm, 0);
104+
const auto local_i2 = compute_index<MemType>(itm, 2);
105+
106+
auto scratch_slice = std::experimental::submdspan(
107+
scr, local_i0, local_i2, std::experimental::full_extent);
108+
109+
const auto start_idx0 = b0_offset + itm.get_global_id(0);
110+
const auto stop_idx0 = sycl::min(n0, start_idx0 + b0_size);
111+
for (size_t global_i0 = start_idx0; global_i0 < stop_idx0;
112+
global_i0 += g0 * w0) {
113+
114+
const auto start_idx2 = b2_offset + itm.get_global_id(2);
115+
const auto stop_idx2 = sycl::min(n2, start_idx2 + b2_size);
116+
for (size_t global_i2 = start_idx2; global_i2 < stop_idx2;
117+
global_i2 += g2 * w2) {
118+
119+
auto data_slice = std::experimental::submdspan(
120+
data, global_i0, std::experimental::full_extent,
121+
global_i2);
122+
123+
for (int ii1 = i1; ii1 < n1; ii1 += w1) {
124+
auto const iw = ii1 - (window - 1);
125+
if(iw >= 0)
126+
scratch_slice(iw) = solver(
127+
data_slice, global_i0, ii1, global_i2);
128+
}
129+
130+
sycl::group_barrier(itm.get_group());
131+
132+
for (int iw = i1; iw < nw; iw += w1) {
133+
data_slice(iw) = scratch_slice(iw);
134+
}
135+
} // end for ii2
136+
} // end for ii0
137+
} // end lambda in parallel_for
138+
); // end parallel_for nd_range
139+
}); // end Q.submit
140+
} // end submit_kernels

src/core/bkma/BasicRange.hpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#pragma once
2+
#include <bkma_tools.hpp>
3+
// class BasicRange : public IAdvectorX {
4+
// protected:
5+
// sycl::queue q_;
6+
// real_t *ftmp_;
7+
8+
// public:
9+
// BasicRange(const AdvectionSolver &solver, sycl::queue q) {
10+
// const auto n0 = solver.params.n0;
11+
// const auto n1 = solver.params.n1;
12+
// const auto n2 = solver.params.n2;
13+
14+
// ftmp_ = sycl::malloc_device<real_t>(n0 * n1 * n2, q_);
15+
// q_.wait();
16+
// }
17+
18+
// ~BasicRange() {
19+
// sycl::free(ftmp_, q_);
20+
// q_.wait();
21+
// }
22+
23+
// sycl::event operator()(sycl::queue &Q, real_t *data,
24+
// const AdvectionSolver &solver) override;
25+
// };
26+
27+
template <MemorySpace MemType, class MySolver, BkmaImpl Impl>
28+
inline std::enable_if_t<Impl == BkmaImpl::BasicRange, sycl::event>
29+
submit_kernels(sycl::queue &Q, span3d_t data, const MySolver &solver,
30+
const size_t b0_size, const size_t b0_offset,
31+
const size_t b2_size, const size_t b2_offset,
32+
const size_t orig_w0, const size_t w1, const size_t orig_w2,
33+
WorkGroupDispatch wg_dispatch, span3d_t global_scratch) {
34+
35+
static_assert(
36+
!(MemType == MemorySpace::Local && BkmaImpl::BasicRange == Impl),
37+
"BasicRange is not supported with MemorySpace::Local");
38+
39+
auto n0 = data.extent(0);
40+
auto n1 = data.extent(1);
41+
auto n2 = data.extent(2);
42+
43+
sycl::range r3d(n0, n1, n2);
44+
45+
Q.submit([&](sycl::handler &cgh) {
46+
cgh.parallel_for(r3d, [=](sycl::id<3> itm) {
47+
const int i1 = itm[1];
48+
const int i0 = itm[0];
49+
const int i2 = itm[2];
50+
51+
global_scratch(i0, i1, i2) =
52+
solver(std::experimental::submdspan(
53+
data, i0, std::experimental::full_extent, i2),
54+
i0, i1, i2);
55+
// barrier
56+
}); // end parallel_for
57+
}); // end Q.submit
58+
Q.wait();
59+
// copy
60+
return Q.submit([&](sycl::handler &cgh) {
61+
cgh.parallel_for(r3d, [=](sycl::id<3> itm) {
62+
const int i1 = itm[1];
63+
const int i0 = itm[0];
64+
const int i2 = itm[2];
65+
data(i0, i1, i2) = global_scratch(i0, i1, i2);
66+
// barrier
67+
}); // end parallel_for
68+
}); // end Q.submit
69+
}

src/core/bkma/MemorySpace.hpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#pragma once
2+
#include <sycl/sycl.hpp>
3+
4+
#ifdef SYCL_IMPLEMENTATION_ONEAPI
5+
#define GET_POINTER get_multi_ptr<sycl::access::decorated::no>().get
6+
#else
7+
#define GET_POINTER get_pointer
8+
#endif
9+
10+
//==============================================================================
11+
//==============================================================================
12+
enum class MemorySpace { Local, Global };
13+
14+
template <MemorySpace MemType> struct MemAllocator;
15+
16+
template <MemorySpace MemType>
17+
static inline size_t compute_index(const sycl::nd_item<3> &itm,
18+
unsigned short dim);
19+
20+
// ==========================================
21+
// ==========================================
22+
/* Local memory functions */
23+
template <> struct MemAllocator<MemorySpace::Local> {
24+
local_acc acc_;
25+
extents_t extents_;
26+
27+
[[nodiscard]] MemAllocator(sycl::range<3> range, sycl::handler &cgh)
28+
: acc_(range, cgh), extents_(range.get(0), range.get(1), range.get(2)) {
29+
}
30+
[[nodiscard]] inline auto get_pointer() const { return acc_.GET_POINTER(); }
31+
32+
[[nodiscard]] inline auto get_extents() const { return extents_; }
33+
};
34+
35+
template <>
36+
inline size_t
37+
compute_index<MemorySpace::Local>(const sycl::nd_item<3> &itm,
38+
unsigned short dim) {
39+
return itm.get_local_id(dim);
40+
}
41+
42+
// ==========================================
43+
// ==========================================
44+
/* Global memory functions */
45+
template <> struct MemAllocator<MemorySpace::Global> {
46+
span3d_t data_;
47+
48+
[[nodiscard]] MemAllocator(span3d_t global_scratch_)
49+
: data_(global_scratch_){};
50+
51+
[[nodiscard]] inline size_t compute_index(const sycl::nd_item<3> &itm,
52+
unsigned short dim) {
53+
return itm.get_global_id(dim);
54+
}
55+
56+
[[nodiscard]] inline auto get_pointer() const {
57+
return data_.data_handle();
58+
}
59+
60+
[[nodiscard]] inline auto get_extents() const {
61+
return extents_t{data_.extent(0), data_.extent(1), data_.extent(2)};
62+
}
63+
};
64+
65+
template <>
66+
inline size_t
67+
compute_index<MemorySpace::Global>(const sycl::nd_item<3> &itm,
68+
unsigned short dim) {
69+
return itm.get_global_id(dim);
70+
}

src/core/bkma/NDRange.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#pragma once
2+
#include <bkma_tools.hpp>
3+
4+
template <MemorySpace MemType, class MySolver, BkmaImpl Impl>
5+
inline std::enable_if_t<Impl == BkmaImpl::NDRange, sycl::event>
6+
submit_kernels(sycl::queue &Q, span3d_t data, const MySolver &solver,
7+
const size_t b0_size, const size_t b0_offset,
8+
const size_t b2_size, const size_t b2_offset,
9+
const size_t orig_w0, const size_t w1, const size_t orig_w2,
10+
WorkGroupDispatch wg_dispatch,
11+
span3d_t global_scratch = span3d_t{}) {
12+
13+
const auto n0 = data.extent(0);
14+
const auto n1 = data.extent(1);
15+
const auto n2 = data.extent(2);
16+
17+
const sycl::range global_size{n0, n1, n2};
18+
const sycl::range local_size{1, n1, 1};
19+
20+
return Q.submit([&](sycl::handler &cgh) {
21+
sycl::local_accessor<real_t, 1> slice_ftmp(sycl::range<1>(n1), cgh);
22+
23+
cgh.parallel_for(sycl::nd_range<3>{global_size, local_size},
24+
[=](auto itm) {
25+
const int i1 = itm.get_local_id(1);
26+
const int i0 = itm.get_global_id(0);
27+
const int i2 = itm.get_global_id(2);
28+
29+
auto slice = std::experimental::submdspan(
30+
data, i0, std::experimental::full_extent, i2);
31+
32+
slice_ftmp[i1] = solver(slice, i0, i1, i2);
33+
34+
sycl::group_barrier(itm.get_group());
35+
36+
slice(i1) = slice_ftmp[i1];
37+
} // end lambda in parallel_for
38+
); // end parallel_for nd_range
39+
}); // end Q.submit
40+
}

src/core/bkma/bkma.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
#include <BasicRange.hpp>
3+
#include <NDRange.hpp>
4+
#include <AdaptiveWg.hpp>
5+
#include <bkma_tools.hpp>
6+
#include <MemorySpace.hpp>
7+
#include <bkma_run.hpp>

0 commit comments

Comments
 (0)