Skip to content

Commit 4549117

Browse files
Optimization of custom_reduce_over_group function.
The function used to perform the custom reduction sequentially in a single work-item (the leader of the work-group). It now reduces cooperatively for a few iterations, and then processes the remaining non-reduced elements sequentially in the leading work-item. The custom_reduce_over_group function got sped up by about a factor of 3. The following now shows the timing of the reduction kernel ``` unitrace -d -v -i 20 python -c "import dpctl.tensor as dpt; dpt.min(dpt.ones(10**7, dtype=dpt.float32)).sycl_queue.wait()" ``` on par with (less than 10% slower than) the int32 kernel, which uses the built-in sycl::reduce_over_group: ``` unitrace -d -v -i 20 python -c "import dpctl.tensor as dpt; dpt.min(dpt.ones(10**7, dtype=dpt.int32)).sycl_queue.wait()" ```
1 parent 0bcd635 commit 4549117

File tree

1 file changed

+59
-5
lines changed

1 file changed

+59
-5
lines changed

dpctl/tensor/libtensor/include/utils/sycl_utils.hpp

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,27 +132,81 @@ size_t choose_workgroup_size(const size_t nelems,
132132
return wg;
133133
}
134134

135+
// NOTE: these helpers live in a *named* namespace on purpose. This file is a
// header; an unnamed namespace here would give every translation unit its own
// internal-linkage copy of `_fold`, so function templates defined in this
// header that call `_fold` would refer to a different entity in each
// translation unit (an ODR hazard).
namespace detail
{

/*! @brief Single pairwise combining step of an in-place reduction.
 *
 * If `lid < cutoff`, stores
 * `op(local_mem_acc[lid], local_mem_acc[step + lid])` into
 * `local_mem_acc[lid]`; otherwise does nothing. No synchronization is
 * performed here, so callers that invoke this from several work-items must
 * synchronize on their own before reading the results.
 *
 * @param local_mem_acc Indexable storage that is read and updated in place.
 * @param lid           Index of the calling work-item (destination slot).
 * @param cutoff        Only work-items with `lid < cutoff` perform the update.
 * @param step          Offset from `lid` to the partner element.
 * @param op            Binary operation used to combine the two elements.
 */
template <typename LocAccT, typename OpT>
void _fold(LocAccT &local_mem_acc,
           const std::uint32_t lid,
           const std::uint32_t cutoff,
           const std::uint32_t step,
           const OpT &op)
{
    if (lid < cutoff) {
        local_mem_acc[lid] = op(local_mem_acc[lid], local_mem_acc[step + lid]);
    }
}

/*! @brief Symmetric variant of the combining step, where the participation
 *  cutoff equals the partner offset.
 *
 * Work-items with `lid < step` store
 * `op(local_mem_acc[lid], local_mem_acc[step + lid])` into
 * `local_mem_acc[lid]`; all others do nothing.
 *
 * @param local_mem_acc Indexable storage that is read and updated in place.
 * @param lid           Index of the calling work-item (destination slot).
 * @param step          Both the participation cutoff and the partner offset.
 * @param op            Binary operation used to combine the two elements.
 */
template <typename LocAccT, typename OpT>
void _fold(LocAccT &local_mem_acc,
           const std::uint32_t lid,
           const std::uint32_t step,
           const OpT &op)
{
    // same operation as the general overload with cutoff == step
    _fold(local_mem_acc, lid, step, step, op);
}

} // namespace detail

// keep pre-existing unqualified `_fold(...)` calls in this header working
using detail::_fold;
162+
135163
template <typename T, typename GroupT, typename LocAccT, typename OpT>
136164
T custom_reduce_over_group(const GroupT &wg,
137165
LocAccT local_mem_acc,
138166
const T &local_val,
139167
const OpT &op)
140168
{
141-
size_t wgs = wg.get_local_linear_range();
142-
local_mem_acc[wg.get_local_linear_id()] = local_val;
169+
const std::uint32_t wgs = wg.get_local_linear_range();
170+
const std::uint32_t lid = wg.get_local_linear_id();
143171

172+
local_mem_acc[lid] = local_val;
144173
sycl::group_barrier(wg, sycl::memory_scope::work_group);
145174

175+
std::uint32_t n_witems = wgs;
176+
if (wgs & (wgs - 1)) {
177+
// wgs is not a power of 2
178+
#pragma unroll
179+
for (std::uint32_t sz = 1024; sz >= 32; sz >>= 1) {
180+
if (n_witems >= sz) {
181+
const std::uint32_t n_witems_ = (n_witems + 1) >> 1;
182+
_fold(local_mem_acc, lid, n_witems - n_witems_, n_witems_, op);
183+
sycl::group_barrier(wg, sycl::memory_scope::work_group);
184+
n_witems = n_witems_;
185+
}
186+
}
187+
}
188+
else {
189+
// wgs is a power of 2
190+
#pragma unroll
191+
for (std::uint32_t sz = 1024; sz >= 32; sz >>= 1) {
192+
if (n_witems >= sz) {
193+
n_witems = (n_witems + 1) >> 1;
194+
_fold(local_mem_acc, lid, n_witems, op);
195+
sycl::group_barrier(wg, sycl::memory_scope::work_group);
196+
}
197+
}
198+
}
199+
146200
T red_val_over_wg = local_mem_acc[0];
147201
if (wg.leader()) {
148-
for (size_t i = 1; i < wgs; ++i) {
202+
for (std::uint32_t i = 1; i < n_witems; ++i) {
149203
red_val_over_wg = op(red_val_over_wg, local_mem_acc[i]);
150204
}
151205
}
152206

153-
sycl::group_barrier(wg, sycl::memory_scope::work_group);
207+
// sycl::group_barrier(wg, sycl::memory_scope::work_group);
154208

155-
return sycl::group_broadcast(wg, red_val_over_wg);
209+
return sycl::group_broadcast(wg, red_val_over_wg, 0);
156210
}
157211

158212
template <typename T, typename GroupT, typename LocAccT, typename OpT>

0 commit comments

Comments
 (0)