
Commit c9ad9ab

Add tests for work group memory extension
1 parent 5d5ec9e commit c9ad9ab

5 files changed: 731 additions & 15 deletions


sycl/test-e2e/WorkGroupMemory/swap_test.cpp renamed to sycl/test-e2e/WorkGroupMemory/basic_usage.cpp

Lines changed: 51 additions & 15 deletions
@@ -5,6 +5,8 @@
#include <sycl/detail/core.hpp>
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>
#include <sycl/group_barrier.hpp>
+#include <sycl/half_type.hpp>
+
namespace syclexp = sycl::ext::oneapi::experimental;

sycl::queue q;
@@ -50,7 +52,9 @@ template <typename T> void swap_scalar(T &a, T &b) {
sycl::nd_range<1> ndr{size, wgsize};
cgh.parallel_for(ndr, [=](sycl::nd_item<1> it) {
syclexp::work_group_memory<T> temp2;
-temp2 = temp; // temp and temp2 have the same underlying data
+temp2 = temp; // temp and temp2 have the same underlying data
+assert(&temp2 == &temp); // check that both objects return same
+// underlying address after assignment
temp = acc_a[0];
acc_a[0] = acc_b[0];
acc_b[0] = temp2; // safe to use temp2
@@ -86,6 +90,8 @@ template <typename T> void swap_scalar(T &a, T &b) {
assert(a == old_b && b == old_a && "Incorrect swap!");

// Same as above but instead of using multi_ptr, use address-of operator.
+// Also verify that get_multi_ptr() returns the same address as address-of
+// operator.
{
sycl::buffer<T, 1> buf_a{&a, 1};
sycl::buffer<T, 1> buf_b{&b, 1};
@@ -96,6 +102,7 @@ template <typename T> void swap_scalar(T &a, T &b) {
syclexp::work_group_memory<T> temp2{cgh};
sycl::nd_range<1> ndr{size, wgsize};
cgh.parallel_for(ndr, [=](sycl::nd_item<> it) {
+assert(&temp == temp.get_multi_ptr().get());
temp = acc_a[0];
acc_a[0] = acc_b[0];
temp2 = *(&temp);
@@ -294,6 +301,8 @@ void swap_array_2d(T (&a)[N][N], T (&b)[N][N], size_t batch_size) {
temp[i][j] = acc_a[i][j];
acc_a[i][j] = acc_b[i][j];
syclexp::work_group_memory<T[N][N]> temp2{temp};
+assert(&temp2 == &temp); // check both objects return same underlying
+// address after copy construction.
acc_b[i][j] = temp2[i][j];
});
});
@@ -342,28 +351,28 @@ void swap_array_2d(T (&a)[N][N], T (&b)[N][N], size_t batch_size) {
// so we can verify that each work-item sees the value written by its leader.
// The test also is a sanity check that different work groups get different
// work group memory locations as otherwise we'd have data races.
-void coherency(size_t size, size_t wgsize) {
+template <typename T> void coherency(size_t size, size_t wgsize) {
q.submit([&](sycl::handler &cgh) {
-syclexp::work_group_memory<int> data{cgh};
+syclexp::work_group_memory<T> data{cgh};
sycl::nd_range<1> ndr{size, wgsize};
cgh.parallel_for(ndr, [=](sycl::nd_item<1> it) {
if (it.get_group().leader()) {
-data = it.get_global_id() / wgsize;
+data = T(it.get_global_id() / wgsize);
}
sycl::group_barrier(it.get_group());
-assert(data == it.get_global_id() / wgsize);
+assert(data == T(it.get_global_id() / wgsize));
});
});
}

constexpr size_t N = 32;
-int main() {
-int intarr1[N][N];
-int intarr2[N][N];
+template <typename T> void test() {
+T intarr1[N][N];
+T intarr2[N][N];
for (int i = 0; i < N; ++i) {
for (int j = 0; j < N; ++j) {
-intarr1[i][j] = i + j;
-intarr2[i][j] = i * j;
+intarr1[i][j] = T(i) + T(j);
+intarr2[i][j] = T(i) * T(j);
}
}
for (int i = 0; i < N; ++i) {
@@ -373,10 +382,37 @@ int main() {
swap_array_1d(intarr1[i], intarr2[i], 8);
}
swap_array_2d(intarr1, intarr2, 8);
-coherency(N, N / 2);
-coherency(N, N / 4);
-coherency(N, N / 8);
-coherency(N, N / 16);
-coherency(N, N / 32);
+coherency<T>(N, N / 2);
+coherency<T>(N, N / 4);
+coherency<T>(N, N / 8);
+coherency<T>(N, N / 16);
+coherency<T>(N, N / 32);
+}
+
+template <typename T> void test_ptr() {
+T arr1[N][N];
+T arr2[N][N];
+for (int i = 0; i < N; ++i) {
+for (int j = 0; j < N; ++j) {
+swap_scalar(arr1[i][j], arr2[i][j]);
+}
+swap_array_1d(arr1[i], arr2[i], 8);
+}
+swap_array_2d(arr1, arr2, 8);
+}
+
+int main() {
+test<int>();
+test<char>();
+test<uint16_t>();
+if (q.get_device().has(sycl::aspect::fp16))
+test<sycl::half>();
+test_ptr<float *>();
+test_ptr<int *>();
+test_ptr<char *>();
+test_ptr<uint16_t *>();
+if (q.get_device().has(sycl::aspect::fp16))
+test_ptr<sycl::half *>();
+test_ptr<float *>();
return 0;
}
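For reference, the pattern the swap tests above exercise can be reduced to the following minimal sketch. It is not taken from the commit and only restates API usage already visible in the diff: a work_group_memory<T> object is created from the handler, captured by the kernel, and then read and written like a reference to a work-group-local T.

// Minimal sketch of the work_group_memory swap pattern exercised above
// (illustrative only; not part of this commit).
#include <cassert>
#include <sycl/detail/core.hpp>
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>

namespace syclexp = sycl::ext::oneapi::experimental;

int main() {
  sycl::queue q;
  int a = 1, b = 2;
  {
    sycl::buffer<int, 1> buf_a{&a, 1};
    sycl::buffer<int, 1> buf_b{&b, 1};
    q.submit([&](sycl::handler &cgh) {
      sycl::accessor acc_a{buf_a, cgh};
      sycl::accessor acc_b{buf_b, cgh};
      // One int of work-group-local memory, allocated per work-group.
      syclexp::work_group_memory<int> temp{cgh};
      cgh.parallel_for(sycl::nd_range<1>{{1}, {1}}, [=](sycl::nd_item<1>) {
        temp = acc_a[0]; // stage a in work-group memory
        acc_a[0] = acc_b[0];
        acc_b[0] = temp; // read it back: a and b end up swapped
      });
    });
  } // buffer destructors write the swapped values back to a and b
  assert(a == 2 && b == 1);
  return 0;
}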
Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
+// REQUIRES: aspect-usm_shared_allocations
+// RUN: %{build} -o %t.out
+// RUN: %{run} %t.out
+
+// The name mangling for free function kernels currently does not work with PTX.
+// UNSUPPORTED: cuda
+
+// Usage of work group memory parameters in free function kernels is not yet
+// implemented.
+// TODO: Remove the following directive once
+// https://github.com/intel/llvm/pull/15861 is merged.
+// XFAIL: *
+// XFAIL-TRACKER: https://github.com/intel/llvm/issues/15927
+
+#include <cassert>
+#include <sycl/detail/core.hpp>
+#include <sycl/ext/intel/math.hpp>
+#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>
+#include <sycl/ext/oneapi/free_function_queries.hpp>
+#include <sycl/group_barrier.hpp>
+#include <sycl/usm.hpp>
+
+using namespace sycl;
+
+// Basic usage reduction test using free function kernels.
+// A global buffer is allocated using USM and it is passed to the kernel on the
+// device. On the device, a work group memory buffer is allocated and each item
+// copies the corresponding element of the global buffer to the corresponding
+// element of the work group memory buffer using its global index. The leader of
+// every work-group, after waiting for every work-item to complete, then sums
+// these values storing the result in another work group memory object. Finally,
+// each work item then verifies that the sum of the work group memory elements
+// equals the sum of the global buffer elements. This is repeated for several
+// data types.
+
+queue q;
+context ctx = q.get_context();
+
+constexpr size_t SIZE = 128;
+constexpr size_t VEC_SIZE = 16;
+constexpr float tolerance = 0.01f;
+
+template <typename T>
+void sum_helper(sycl::ext::oneapi::experimental::work_group_memory<T[]> mem,
+sycl::ext::oneapi::experimental::work_group_memory<T> ret,
+size_t WGSIZE) {
+for (int i = 0; i < WGSIZE; ++i) {
+ret = ret + mem[i];
+}
+}
+
+template <typename T>
+SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
+(ext::oneapi::experimental::nd_range_kernel<1>))
+void sum(sycl::ext::oneapi::experimental::work_group_memory<T[]> mem, T *buf,
+sycl::ext::oneapi::experimental::work_group_memory<T> result,
+T expected, size_t WGSIZE, bool UseHelper) {
+const auto it = sycl::ext::oneapi::this_work_item::get_nd_item<1>();
+size_t local_id = it.get_local_id();
+mem[local_id] = buf[local_id];
+group_barrier(it.get_group());
+if (it.get_group().leader()) {
+result = 0;
+if (!UseHelper) {
+for (int i = 0; i < WGSIZE; ++i) {
+result = result + mem[i];
+}
+} else {
+sum_helper(mem, result, WGSIZE);
+}
+assert(result == expected);
+}
+}
+
+// Explicit instantiations for the relevant data types.
+#define SUM(T) \
+template void sum<T>( \
+sycl::ext::oneapi::experimental::work_group_memory<T[]> mem, T * buf, \
+sycl::ext::oneapi::experimental::work_group_memory<T> result, \
+T expected, size_t WGSIZE, bool UseHelper);
+
+SUM(int)
+SUM(uint16_t)
+SUM(half)
+SUM(double)
+SUM(float)
+SUM(char)
+SUM(bool)
+
+template <typename T>
+SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
+(ext::oneapi::experimental::nd_range_kernel<1>))
+void sum_marray(
+sycl::ext::oneapi::experimental::work_group_memory<sycl::marray<T, 16>> mem,
+T *buf, sycl::ext::oneapi::experimental::work_group_memory<T> result,
+T expected) {
+const auto it = sycl::ext::oneapi::this_work_item::get_nd_item<1>();
+size_t local_id = it.get_local_id();
+constexpr float tolerance = 0.01f;
+sycl::marray<T, 16> &data = mem;
+data[local_id] = buf[local_id];
+group_barrier(it.get_group());
+if (it.get_group().leader()) {
+result = 0;
+for (int i = 0; i < 16; ++i) {
+result = result + data[i];
+}
+assert((result - expected) * (result - expected) <= tolerance);
+}
+}
+
+// Explicit instantiations for the relevant data types.
+#define SUM_MARRAY(T) \
+template void sum_marray<T>( \
+sycl::ext::oneapi::experimental::work_group_memory<sycl::marray<T, 16>> \
+mem, \
+T * buf, sycl::ext::oneapi::experimental::work_group_memory<T> result, \
+T expected);
+
+SUM_MARRAY(int);
+SUM_MARRAY(float);
+SUM_MARRAY(double);
+SUM_MARRAY(char);
+SUM_MARRAY(bool);
+SUM_MARRAY(half);
+
+template <typename T>
+SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
+(ext::oneapi::experimental::nd_range_kernel<1>))
+void sum_vec(
+sycl::ext::oneapi::experimental::work_group_memory<sycl::vec<T, 16>> mem,
+T *buf, sycl::ext::oneapi::experimental::work_group_memory<T> result,
+T expected) {
+const auto it = sycl::ext::oneapi::this_work_item::get_nd_item<1>();
+size_t local_id = it.get_local_id();
+constexpr float tolerance = 0.01f;
+sycl::vec<T, 16> &data = mem;
+data[local_id] = buf[local_id];
+group_barrier(it.get_group());
+if (it.get_group().leader()) {
+result = 0;
+for (int i = 0; i < 16; ++i) {
+result = result + data[i];
+}
+assert((result - expected) * (result - expected) <= tolerance);
+}
+}
+
+// Explicit instantiations for the relevant data types.
+#define SUM_VEC(T) \
+template void sum_vec<T>( \
+sycl::ext::oneapi::experimental::work_group_memory<sycl::vec<T, 16>> \
+mem, \
+T * buf, sycl::ext::oneapi::experimental::work_group_memory<T> result, \
+T expected);
+
+SUM_VEC(int);
+SUM_VEC(float);
+SUM_VEC(double);
+SUM_VEC(char);
+SUM_VEC(bool);
+SUM_VEC(half);
+
+template <typename T, typename... Ts> void test_marray() {
+if (std::is_same_v<sycl::half, T> && !q.get_device().has(sycl::aspect::fp16))
+return;
+constexpr size_t WGSIZE = VEC_SIZE;
+T *buf = malloc_shared<T>(WGSIZE, q);
+assert(buf && "Shared USM allocation failed!");
+T expected = 0;
+for (int i = 0; i < WGSIZE; ++i) {
+buf[i] = ext::intel::math::sqrt(T(i));
+expected = expected + buf[i];
+}
+nd_range ndr{{SIZE}, {WGSIZE}};
+#ifndef __SYCL_DEVICE_ONLY__
+// Get the kernel object for the sum_marray kernel.
+auto Bundle = get_kernel_bundle<sycl::bundle_state::executable>(ctx);
+kernel_id sum_id = ext::oneapi::experimental::get_kernel_id<sum_marray<T>>();
+kernel k_sum = Bundle.get_kernel(sum_id);
+q.submit([&](sycl::handler &cgh) {
+ext::oneapi::experimental::work_group_memory<marray<T, WGSIZE>> mem{cgh};
+ext::oneapi::experimental::work_group_memory<T> result{cgh};
+cgh.set_args(mem, buf, result, expected);
+cgh.parallel_for(ndr, k_sum);
+}).wait();
+#endif // __SYCL_DEVICE_ONLY__
+free(buf, q);
+if constexpr (sizeof...(Ts))
+test_marray<Ts...>();
+}
+
+template <typename T, typename... Ts> void test_vec() {
+if (std::is_same_v<sycl::half, T> && !q.get_device().has(sycl::aspect::fp16))
+return;
+constexpr size_t WGSIZE = VEC_SIZE;
+T *buf = malloc_shared<T>(WGSIZE, q);
+assert(buf && "Shared USM allocation failed!");
+T expected = 0;
+for (int i = 0; i < WGSIZE; ++i) {
+buf[i] = ext::intel::math::sqrt(T(i));
+expected = expected + buf[i];
+}
+nd_range ndr{{SIZE}, {WGSIZE}};
+#ifndef __SYCL_DEVICE_ONLY__
+// Get the kernel object for the sum_vec kernel.
+auto Bundle = get_kernel_bundle<sycl::bundle_state::executable>(ctx);
+kernel_id sum_id = ext::oneapi::experimental::get_kernel_id<sum_vec<T>>();
+kernel k_sum = Bundle.get_kernel(sum_id);
+q.submit([&](sycl::handler &cgh) {
+ext::oneapi::experimental::work_group_memory<vec<T, WGSIZE>> mem{cgh};
+ext::oneapi::experimental::work_group_memory<T> result{cgh};
+cgh.set_args(mem, buf, result, expected);
+cgh.parallel_for(ndr, k_sum);
+}).wait();
+#endif // __SYCL_DEVICE_ONLY__
+free(buf, q);
+if constexpr (sizeof...(Ts))
+test_vec<Ts...>();
+}
+
+template <typename T, typename... Ts>
+void test(size_t SIZE, size_t WGSIZE, bool UseHelper) {
+if (std::is_same_v<sycl::half, T> && !q.get_device().has(sycl::aspect::fp16))
+return;
+T *buf = malloc_shared<T>(WGSIZE, q);
+assert(buf && "Shared USM allocation failed!");
+T expected = 0;
+for (int i = 0; i < WGSIZE; ++i) {
+buf[i] = T(i);
+expected = expected + buf[i];
+}
+nd_range ndr{{SIZE}, {WGSIZE}};
+// The following ifndef is required due to a number of limitations of free
+// function kernels. See CMPLRLLVM-61498.
+// TODO: Remove it once these limitations are no longer there.
+#ifndef __SYCL_DEVICE_ONLY__
+// Get the kernel object for the sum kernel.
+auto Bundle = get_kernel_bundle<sycl::bundle_state::executable>(ctx);
+kernel_id sum_id = ext::oneapi::experimental::get_kernel_id<sum<T>>();
+kernel k_sum = Bundle.get_kernel(sum_id);
+q.submit([&](sycl::handler &cgh) {
+ext::oneapi::experimental::work_group_memory<T[]> mem{WGSIZE, cgh};
+ext::oneapi::experimental::work_group_memory<T> result{cgh};
+cgh.set_args(mem, buf, result, expected, WGSIZE, UseHelper);
+cgh.parallel_for(ndr, k_sum);
+}).wait();
+
+#endif // __SYCL_DEVICE_ONLY__
+free(buf, q);
+if constexpr (sizeof...(Ts))
+test<Ts...>(SIZE, WGSIZE, UseHelper);
+}
+
+int main() {
+test<int, uint16_t, half, double, float>(SIZE, SIZE, true /* UseHelper */);
+test<int, float, half>(SIZE, SIZE, false);
+test<int, double, char>(SIZE, SIZE / 2, false);
+test<int, bool, char>(SIZE, SIZE / 4, false);
+test_marray<float, double, half>();
+test_vec<float, double, half>();
+return 0;
+}
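The host-side launch sequence this test uses for free function kernels (declare the kernel with SYCL_EXT_ONEAPI_FUNCTION_PROPERTY, look it up in an executable kernel_bundle, bind arguments with set_args, then submit with parallel_for and the kernel object) can be condensed into the sketch below. It is not part of the commit: the kernel plus_one and its parameters are illustrative only, and the #ifndef __SYCL_DEVICE_ONLY__ guard the test needs around the submission code is omitted for brevity.

// Condensed sketch of launching a free function kernel that takes a
// work_group_memory parameter (illustrative only; not part of this commit).
#include <sycl/detail/core.hpp>
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>
#include <sycl/ext/oneapi/free_function_queries.hpp>
#include <sycl/usm.hpp>

namespace syclexp = sycl::ext::oneapi::experimental;

// Free function kernel: marked with the nd_range_kernel property and taking a
// work_group_memory argument directly, which is the feature under test.
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
void plus_one(syclexp::work_group_memory<int[]> scratch, int *buf) {
  auto it = sycl::ext::oneapi::this_work_item::get_nd_item<1>();
  size_t id = it.get_local_id();
  scratch[id] = buf[id] + 1; // stage the value in work-group-local memory
  buf[id] = scratch[id];
}

int main() {
  sycl::queue q;
  sycl::context ctx = q.get_context();
  constexpr size_t WGSIZE = 16;
  int *buf = sycl::malloc_shared<int>(WGSIZE, q);
  for (size_t i = 0; i < WGSIZE; ++i)
    buf[i] = int(i);
  // Free function kernels are launched by kernel object: look the kernel up
  // by its kernel_id and bind the arguments with set_args().
  auto bundle = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctx);
  sycl::kernel_id id = syclexp::get_kernel_id<plus_one>();
  sycl::kernel krn = bundle.get_kernel(id);
  q.submit([&](sycl::handler &cgh) {
     // A work-group-local int[WGSIZE] buffer, sized at submission time.
     syclexp::work_group_memory<int[]> scratch{WGSIZE, cgh};
     cgh.set_args(scratch, buf);
     cgh.parallel_for(sycl::nd_range<1>{{WGSIZE}, {WGSIZE}}, krn);
   }).wait();
  sycl::free(buf, q);
  return 0;
}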
