Skip to content

Commit 183042a

Browse files
committed
Address more feedback
1 parent d41fee7 commit 183042a

File tree

1 file changed

+39
-37
lines changed

1 file changed

+39
-37
lines changed

sycl/test-e2e/FreeFunctionKernels/free_function_kernel_local_memory.cpp

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
// either by way of the work group scratch memory extension or the work group
1111
// static memory extension.
1212

13-
#include <sycl/ext/oneapi/work_group_static.hpp>
14-
1513
#include "helpers.hpp"
14+
1615
#include <cassert>
1716
#include <sycl/ext/oneapi/experimental/enqueue_functions.hpp>
1817
#include <sycl/ext/oneapi/free_function_queries.hpp>
18+
#include <sycl/ext/oneapi/work_group_static.hpp>
1919
#include <sycl/group_barrier.hpp>
2020
#include <sycl/usm.hpp>
2121

@@ -25,36 +25,38 @@ namespace syclexp = sycl::ext::oneapi::experimental;
2525
constexpr int SIZE = 16;
2626

2727
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
28-
void scratch_kernel(float *src, float *dst) {
28+
void scratchKernel(float *src, float *dst) {
2929
size_t lid = syclext::this_work_item::get_nd_item<1>().get_local_linear_id();
30-
float *local_mem = (float *)syclexp::get_work_group_scratch_memory();
31-
local_mem[lid] = 2 * src[lid];
32-
dst[lid] = local_mem[lid];
30+
float *localMem =
31+
reinterpret_cast<float *>(syclexp::get_work_group_scratch_memory());
32+
localMem[lid] = 2 * src[lid];
33+
dst[lid] = localMem[lid];
3334
}
3435

3536
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
36-
void static_kernel(float *src, float *dst) {
37+
void staticKernel(float *src, float *dst) {
3738
sycl::nd_item<1> item = syclext::this_work_item::get_nd_item<1>();
3839
size_t lid = item.get_local_linear_id();
39-
syclexp::work_group_static<float[SIZE]> local_mem;
40-
local_mem[lid] = src[lid] * src[lid];
40+
syclexp::work_group_static<float[SIZE]> localMem;
41+
localMem[lid] = src[lid] * src[lid];
4142
sycl::group_barrier(item.get_group());
4243
if (item.get_group().leader()) { // Check that memory is indeed shared between
43-
// the work group
44+
// the work group.
4445
for (int i = 0; i < SIZE; ++i)
45-
assert(local_mem[i] == src[i] * src[i]);
46+
assert(localMem[i] == src[i] * src[i]);
4647
}
47-
dst[lid] = local_mem[lid];
48+
dst[lid] = localMem[lid];
4849
}
4950

5051
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
51-
void scratch_static_kernel(float *src, float *dst) {
52+
void scratchStaticKernel(float *src, float *dst) {
5253
size_t lid = syclext::this_work_item::get_nd_item<1>().get_local_linear_id();
53-
float *scratch_mem = (float *)syclexp::get_work_group_scratch_memory();
54-
syclexp::work_group_static<float[SIZE]> static_mem;
55-
scratch_mem[lid] = src[lid];
56-
static_mem[lid] = src[lid];
57-
dst[lid] = scratch_mem[lid] + static_mem[lid];
54+
float *scratchMem =
55+
reinterpret_cast<float *>(syclexp::get_work_group_scratch_memory());
56+
syclexp::work_group_static<float[SIZE]> staticMem;
57+
scratchMem[lid] = src[lid];
58+
staticMem[lid] = src[lid];
59+
dst[lid] = scratchMem[lid] + staticMem[lid];
5860
}
5961

6062
int main() {
@@ -66,41 +68,41 @@ int main() {
6668
src[i] = i;
6769
}
6870

69-
auto scratchbndl = syclexp::get_kernel_bundle<scratch_kernel,
70-
sycl::bundle_state::executable>(
71-
q.get_context());
72-
auto staticbndl =
73-
syclexp::get_kernel_bundle<static_kernel, sycl::bundle_state::executable>(
71+
auto scratchBndl =
72+
syclexp::get_kernel_bundle<scratchKernel, sycl::bundle_state::executable>(
73+
q.get_context());
74+
auto staticBndl =
75+
syclexp::get_kernel_bundle<staticKernel, sycl::bundle_state::executable>(
7476
q.get_context());
75-
auto scratchstaticbndl = syclexp::get_kernel_bundle<
76-
scratch_static_kernel, sycl::bundle_state::executable>(q.get_context());
77-
78-
sycl::kernel ScratchKernel =
79-
scratchbndl.template ext_oneapi_get_kernel<scratch_kernel>();
80-
sycl::kernel StaticKernel =
81-
staticbndl.template ext_oneapi_get_kernel<static_kernel>();
82-
sycl::kernel ScratchStaticKernel =
83-
scratchstaticbndl.template ext_oneapi_get_kernel<scratch_static_kernel>();
84-
syclexp::launch_config ScratchKernelcfg{
77+
auto scratchStaticBndl = syclexp::get_kernel_bundle<
78+
scratchStaticKernel, sycl::bundle_state::executable>(q.get_context());
79+
80+
sycl::kernel scratchKrn =
81+
scratchBndl.template ext_oneapi_get_kernel<scratchKernel>();
82+
sycl::kernel staticKrn =
83+
staticBndl.template ext_oneapi_get_kernel<staticKernel>();
84+
sycl::kernel scratchStaticKrn =
85+
scratchStaticBndl.template ext_oneapi_get_kernel<scratchStaticKernel>();
86+
syclexp::launch_config scratchKernelcfg{
8587
::sycl::nd_range<1>(::sycl::range<1>(SIZE), ::sycl::range<1>(SIZE)),
8688
syclexp::properties{
8789
syclexp::work_group_scratch_size(SIZE * sizeof(float))}};
88-
syclexp::launch_config StaticKernelcfg{
90+
syclexp::launch_config staticKernelcfg{
8991
::sycl::nd_range<1>(::sycl::range<1>(SIZE), ::sycl::range<1>(SIZE))};
9092

91-
syclexp::nd_launch(q, ScratchKernelcfg, ScratchKernel, src, dst);
93+
syclexp::nd_launch(q, scratchKernelcfg, scratchKrn, src, dst);
9294
q.wait();
9395
for (int i = 0; i < SIZE; i++) {
9496
assert(dst[i] == 2 * src[i]);
9597
}
9698

97-
syclexp::nd_launch(q, StaticKernelcfg, StaticKernel, src, dst);
99+
syclexp::nd_launch(q, staticKernelcfg, staticKrn, src, dst);
98100
q.wait();
99101
for (int i = 0; i < SIZE; i++) {
100102
assert(dst[i] == src[i] * src[i]);
101103
}
102104

103-
syclexp::nd_launch(q, ScratchKernelcfg, ScratchStaticKernel, src, dst);
105+
syclexp::nd_launch(q, scratchKernelcfg, scratchStaticKrn, src, dst);
104106
q.wait();
105107
for (int i = 0; i < SIZE; i++) {
106108
assert(dst[i] == 2 * src[i]);

0 commit comments

Comments
 (0)