#pragma once
#include <bkma_tools.hpp>
3+
4+ // //==============================================================================
5+ // class AdaptiveWg : public IAdvectorX {
6+ // protected:
7+ // using IAdvectorX::IAdvectorX;
8+
9+ // /* We should be able to query max_batchs to the API.
10+ // | x | y/z |
11+ // CUDA:| 2**31-1 | 2**16-1 |
12+ // HIP :| 2**32-1 | 2**32-1 |
13+ // L0 :| 2**32-1 | 2**32-1 | (compile with -fno-sycl-query-fit-in-int)
14+ // CPU : a lot */
15+ // const size_t max_batchs_x_ = 65536 - 1;
16+ // const size_t max_batchs_yz_ = 65536 - 1;
17+
18+ // BatchConfig1D dispatch_dim0_;
19+ // BatchConfig1D dispatch_dim2_;
20+ // WorkItemDispatch local_size_;
21+ // WorkGroupDispatch wg_dispatch_;
22+
23+ // public:
24+ // sycl::event operator()(sycl::queue &Q, real_t *fdist_dev,
25+ // const AdvectionSolver &solver) override;
26+
27+ // AdaptiveWg() = delete;
28+
29+ // AdaptiveWg(const AdvectionSolver &solver, sycl::queue q) {
30+ // const auto n0 = solver.params.n0;
31+ // const auto n1 = solver.params.n1;
32+ // const auto n2 = solver.params.n2;
33+
34+ // dispatch_dim0_ = init_1d_blocking(n0, max_batchs_x_);
35+ // dispatch_dim2_ = init_1d_blocking(n2, max_batchs_yz_);
36+
37+ // // SYCL query returns the size in bytes
38+ // auto max_elem_local_mem =
39+ // q.get_device().get_info<sycl::info::device::local_mem_size>() /
40+ // sizeof(real_t);
41+
42+ // local_size_.set_ideal_sizes(solver.params.pref_wg_size, n0, n1, n2);
43+ // local_size_.adjust_sizes_mem_limit(max_elem_local_mem, n1);
44+
45+ // wg_dispatch_.s0_ = solver.params.seq_size0;
46+ // wg_dispatch_.s2_ = solver.params.seq_size2;
47+
48+ // // TODO: this line is overriden inside the kernel!!! useless
49+ // // wg_dispatch_.set_num_work_groups(n0, n2, dispatch_dim0_.n_batch_,
50+ // // dispatch_dim2_.n_batch_,
51+ // // local_size_.w0_, local_size_.w2_);
52+ // }
53+ // };
54+
55+ // ==========================================
56+ // ==========================================
57+ template <MemorySpace MemType, class MySolver , BkmaImpl Impl>
58+ inline std::enable_if_t <Impl == BkmaImpl::AdaptiveWg, sycl::event>
59+ submit_kernels (sycl::queue &Q, span3d_t data, const MySolver &solver,
60+ const size_t b0_size, const size_t b0_offset,
61+ const size_t b2_size, const size_t b2_offset,
62+ const size_t orig_w0, const size_t w1, const size_t orig_w2,
63+ WorkGroupDispatch wg_dispatch,
64+ span3d_t global_scratch = span3d_t {}) {
65+
66+ const auto w0 = sycl::min (orig_w0, b0_size);
67+ const auto w2 = sycl::min (orig_w2, b2_size);
68+
69+ wg_dispatch.set_num_work_groups (b0_size, b2_size, 1 , 1 , w0, w2);
70+ auto const seq_size0 = wg_dispatch.s0_ ;
71+ auto const seq_size2 = wg_dispatch.s2_ ;
72+ auto const g0 = wg_dispatch.g0_ ;
73+ auto const g2 = wg_dispatch.g2_ ;
74+
75+ const sycl::range<3 > global_size (g0 * w0, w1, g2 * w2);
76+ const sycl::range<3 > local_size (w0, w1, w2);
77+
78+ auto n0 = data.extent (0 );
79+ auto n1 = data.extent (1 );
80+ auto n2 = data.extent (2 );
81+
82+ const auto window = solver.window ();
83+ const auto nw = n1 - (window-1 );
84+
85+ return Q.submit ([&](sycl::handler &cgh) {
86+ auto mallocator = [&]() {
87+ if constexpr (MemType == MemorySpace::Local) {
88+ sycl::range<3 > acc_range (w0, w2, nw);
89+ return MemAllocator<MemType>(acc_range, cgh);
90+ } else {
91+ extents_t ext (b0_size, n2, n1);
92+ return MemAllocator<MemType>(global_scratch);
93+ }
94+ }();
95+
96+ cgh.parallel_for (
97+ sycl::nd_range<3 >{global_size, local_size},
98+ [=](auto itm) {
99+ span3d_t scr (mallocator.get_pointer (),
100+ mallocator.get_extents ());
101+
102+ const auto i1 = itm.get_local_id (1 );
103+ const auto local_i0 = compute_index<MemType>(itm, 0 );
104+ const auto local_i2 = compute_index<MemType>(itm, 2 );
105+
106+ auto scratch_slice = std::experimental::submdspan (
107+ scr, local_i0, local_i2, std::experimental::full_extent);
108+
109+ const auto start_idx0 = b0_offset + itm.get_global_id (0 );
110+ const auto stop_idx0 = sycl::min (n0, start_idx0 + b0_size);
111+ for (size_t global_i0 = start_idx0; global_i0 < stop_idx0;
112+ global_i0 += g0 * w0) {
113+
114+ const auto start_idx2 = b2_offset + itm.get_global_id (2 );
115+ const auto stop_idx2 = sycl::min (n2, start_idx2 + b2_size);
116+ for (size_t global_i2 = start_idx2; global_i2 < stop_idx2;
117+ global_i2 += g2 * w2) {
118+
119+ auto data_slice = std::experimental::submdspan (
120+ data, global_i0, std::experimental::full_extent,
121+ global_i2);
122+
123+ for (int ii1 = i1; ii1 < n1; ii1 += w1) {
124+ auto const iw = ii1 - (window - 1 );
125+ if (iw >= 0 )
126+ scratch_slice (iw) = solver (
127+ data_slice, global_i0, ii1, global_i2);
128+ }
129+
130+ sycl::group_barrier (itm.get_group ());
131+
132+ for (int iw = i1; iw < nw; iw += w1) {
133+ data_slice (iw) = scratch_slice (iw);
134+ }
135+ } // end for ii2
136+ } // end for ii0
137+ } // end lambda in parallel_for
138+ ); // end parallel_for nd_range
139+ }); // end Q.submit
140+ } // end submit_kernels