Skip to content

Commit a3f8114

Browse files
committed
Add more testing
1 parent 6e814de commit a3f8114

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

sycl/test-e2e/FreeFunctionKernels/free_function_kernel_local_memory.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,22 @@ void static_kernel(float *src, float *dst) {
4545
dst[lid] = local_mem[lid];
4646
}
4747

48+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
49+
void scratch_static_kernel(float *src, float *dst) {
50+
size_t lid = syclext::this_work_item::get_nd_item<1>().get_local_linear_id();
51+
float *scratch_mem = (float *)syclexp::get_work_group_scratch_memory();
52+
syclexp::work_group_static<float[SIZE]> static_mem;
53+
scratch_mem[lid] = src[lid];
54+
static_mem[lid] = src[lid];
55+
dst[lid] = scratch_mem[lid] + static_mem[lid];
56+
}
57+
4858
int main() {
4959
sycl::queue q;
5060
float *src = sycl::malloc_shared<float>(SIZE, q);
5161
float *dst = sycl::malloc_shared<float>(SIZE, q);
5262

53-
for (int i = 1; i < SIZE; i++) {
63+
for (int i = 0; i < SIZE; i++) {
5464
src[i] = i;
5565
}
5666

@@ -60,11 +70,15 @@ int main() {
6070
auto staticbndl =
6171
syclexp::get_kernel_bundle<static_kernel, sycl::bundle_state::executable>(
6272
q.get_context());
73+
auto scratchstaticbndl = syclexp::get_kernel_bundle<
74+
scratch_static_kernel, sycl::bundle_state::executable>(q.get_context());
6375

6476
sycl::kernel ScratchKernel =
6577
scratchbndl.template ext_oneapi_get_kernel<scratch_kernel>();
6678
sycl::kernel StaticKernel =
6779
staticbndl.template ext_oneapi_get_kernel<static_kernel>();
80+
sycl::kernel ScratchStaticKernel =
81+
scratchstaticbndl.template ext_oneapi_get_kernel<scratch_static_kernel>();
6882
syclexp::launch_config ScratchKernelcfg{
6983
::sycl::nd_range<1>(::sycl::range<1>(SIZE), ::sycl::range<1>(SIZE)),
7084
syclexp::properties{
@@ -84,6 +98,12 @@ int main() {
8498
assert(dst[i] == src[i] * src[i]);
8599
}
86100

101+
syclexp::nd_launch(q, ScratchKernelcfg, ScratchStaticKernel, src, dst);
102+
q.wait();
103+
for (int i = 0; i < SIZE; i++) {
104+
assert(dst[i] == 2 * src[i]);
105+
}
106+
87107
sycl::free(src, q);
88108
sycl::free(dst, q);
89109
return 0;

0 commit comments

Comments
 (0)