Skip to content

Commit 17c4003

Browse files
committed
Merge branch 'work_group_memory_tests' of https://github.com/lbushi25/llvm into work_group_memory_tests
2 parents b7bb745 + 83887be commit 17c4003

File tree

5 files changed

+724
-15
lines changed

5 files changed

+724
-15
lines changed

sycl/test-e2e/WorkGroupMemory/swap_test.cpp renamed to sycl/test-e2e/WorkGroupMemory/basic_usage.cpp

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <sycl/detail/core.hpp>
66
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>
77
#include <sycl/group_barrier.hpp>
8+
#include <sycl/half_type.hpp>
9+
810
namespace syclexp = sycl::ext::oneapi::experimental;
911

1012
sycl::queue q;
@@ -50,7 +52,9 @@ template <typename T> void swap_scalar(T &a, T &b) {
5052
sycl::nd_range<1> ndr{size, wgsize};
5153
cgh.parallel_for(ndr, [=](sycl::nd_item<1> it) {
5254
syclexp::work_group_memory<T> temp2;
53-
temp2 = temp; // temp and temp2 have the same underlying data
55+
temp2 = temp; // temp and temp2 have the same underlying data
56+
assert(&temp2 == &temp); // check that both objects return same
57+
// underlying address after assignment
5458
temp = acc_a[0];
5559
acc_a[0] = acc_b[0];
5660
acc_b[0] = temp2; // safe to use temp2
@@ -86,6 +90,8 @@ template <typename T> void swap_scalar(T &a, T &b) {
8690
assert(a == old_b && b == old_a && "Incorrect swap!");
8791

8892
// Same as above but instead of using multi_ptr, use address-of operator.
93+
// Also verify that get_multi_ptr() returns the same address as address-of
94+
// operator.
8995
{
9096
sycl::buffer<T, 1> buf_a{&a, 1};
9197
sycl::buffer<T, 1> buf_b{&b, 1};
@@ -96,6 +102,7 @@ template <typename T> void swap_scalar(T &a, T &b) {
96102
syclexp::work_group_memory<T> temp2{cgh};
97103
sycl::nd_range<1> ndr{size, wgsize};
98104
cgh.parallel_for(ndr, [=](sycl::nd_item<> it) {
105+
assert(&temp == temp.get_multi_ptr().get());
99106
temp = acc_a[0];
100107
acc_a[0] = acc_b[0];
101108
temp2 = *(&temp);
@@ -294,6 +301,8 @@ void swap_array_2d(T (&a)[N][N], T (&b)[N][N], size_t batch_size) {
294301
temp[i][j] = acc_a[i][j];
295302
acc_a[i][j] = acc_b[i][j];
296303
syclexp::work_group_memory<T[N][N]> temp2{temp};
304+
assert(&temp2 == &temp); // check both objects return same underlying
305+
// address after copy construction.
297306
acc_b[i][j] = temp2[i][j];
298307
});
299308
});
@@ -342,28 +351,28 @@ void swap_array_2d(T (&a)[N][N], T (&b)[N][N], size_t batch_size) {
342351
// so we can verify that each work-item sees the value written by its leader.
343352
// The test also is a sanity check that different work groups get different
344353
// work group memory locations as otherwise we'd have data races.
345-
void coherency(size_t size, size_t wgsize) {
354+
template <typename T> void coherency(size_t size, size_t wgsize) {
346355
q.submit([&](sycl::handler &cgh) {
347-
syclexp::work_group_memory<int> data{cgh};
356+
syclexp::work_group_memory<T> data{cgh};
348357
sycl::nd_range<1> ndr{size, wgsize};
349358
cgh.parallel_for(ndr, [=](sycl::nd_item<1> it) {
350359
if (it.get_group().leader()) {
351-
data = it.get_global_id() / wgsize;
360+
data = T(it.get_global_id() / wgsize);
352361
}
353362
sycl::group_barrier(it.get_group());
354-
assert(data == it.get_global_id() / wgsize);
363+
assert(data == T(it.get_global_id() / wgsize));
355364
});
356365
});
357366
}
358367

359368
constexpr size_t N = 32;
360-
int main() {
361-
int intarr1[N][N];
362-
int intarr2[N][N];
369+
template <typename T> void test() {
370+
T intarr1[N][N];
371+
T intarr2[N][N];
363372
for (int i = 0; i < N; ++i) {
364373
for (int j = 0; j < N; ++j) {
365-
intarr1[i][j] = i + j;
366-
intarr2[i][j] = i * j;
374+
intarr1[i][j] = T(i) + T(j);
375+
intarr2[i][j] = T(i) * T(j);
367376
}
368377
}
369378
for (int i = 0; i < N; ++i) {
@@ -373,10 +382,37 @@ int main() {
373382
swap_array_1d(intarr1[i], intarr2[i], 8);
374383
}
375384
swap_array_2d(intarr1, intarr2, 8);
376-
coherency(N, N / 2);
377-
coherency(N, N / 4);
378-
coherency(N, N / 8);
379-
coherency(N, N / 16);
380-
coherency(N, N / 32);
385+
coherency<T>(N, N / 2);
386+
coherency<T>(N, N / 4);
387+
coherency<T>(N, N / 8);
388+
coherency<T>(N, N / 16);
389+
coherency<T>(N, N / 32);
390+
}
391+
392+
template <typename T> void test_ptr() {
393+
T arr1[N][N];
394+
T arr2[N][N];
395+
for (int i = 0; i < N; ++i) {
396+
for (int j = 0; j < N; ++j) {
397+
swap_scalar(arr1[i][j], arr2[i][j]);
398+
}
399+
swap_array_1d(arr1[i], arr2[i], 8);
400+
}
401+
swap_array_2d(arr1, arr2, 8);
402+
}
403+
404+
int main() {
405+
test<int>();
406+
test<char>();
407+
test<uint16_t>();
408+
if (q.get_device().has(sycl::aspect::fp16))
409+
test<sycl::half>();
410+
test_ptr<float *>();
411+
test_ptr<int *>();
412+
test_ptr<char *>();
413+
test_ptr<uint16_t *>();
414+
if (q.get_device().has(sycl::aspect::fp16))
415+
test_ptr<sycl::half *>();
416+
test_ptr<float *>();
381417
return 0;
382418
}
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
// REQUIRES: aspect-usm_shared_allocations
2+
// RUN: %{build} -o %t.out
3+
// RUN: %{run} %t.out
4+
5+
// The name mangling for free function kernels currently does not work with PTX.
6+
// UNSUPPORTED: cuda
7+
8+
// Usage of work group memory parameters in free function kernels is not yet
9+
// implemented.
10+
// TODO: Remove the following directive once
11+
// https://github.com/intel/llvm/pull/15861 is merged.
12+
// XFAIL: *
13+
// XFAIL-TRACKER: https://github.com/intel/llvm/issues/15927
14+
15+
#include <cassert>
16+
#include <sycl/detail/core.hpp>
17+
#include <sycl/ext/intel/math.hpp>
18+
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>
19+
#include <sycl/ext/oneapi/free_function_queries.hpp>
20+
#include <sycl/group_barrier.hpp>
21+
#include <sycl/usm.hpp>
22+
23+
using namespace sycl;
24+
25+
// Basic usage reduction test using free function kernels.
26+
// A global buffer is allocated using USM and it is passed to the kernel on the
27+
// device. On the device, a work group memory buffer is allocated and each item
28+
// copies the corresponding element of the global buffer to the corresponding
29+
// element of the work group memory buffer using its global index. The leader of
30+
// every work-group, after waiting for every work-item to complete, then sums
31+
// these values storing the result in another work group memory object. Finally,
32+
// each work item then verifies that the sum of the work group memory elements
33+
// equals the sum of the global buffer elements. This is repeated for several
34+
// data types.
35+
36+
queue q;
37+
context ctx = q.get_context();
38+
39+
constexpr size_t SIZE = 128;
40+
constexpr size_t VEC_SIZE = 16;
41+
42+
template <typename T>
43+
void sum_helper(sycl::ext::oneapi::experimental::work_group_memory<T[]> mem,
44+
sycl::ext::oneapi::experimental::work_group_memory<T> ret,
45+
size_t WGSIZE) {
46+
for (int i = 0; i < WGSIZE; ++i) {
47+
ret = ret + mem[i];
48+
}
49+
}
50+
51+
template <typename T>
52+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
53+
(ext::oneapi::experimental::nd_range_kernel<1>))
54+
void sum(sycl::ext::oneapi::experimental::work_group_memory<T[]> mem, T *buf,
55+
sycl::ext::oneapi::experimental::work_group_memory<T> result,
56+
T expected, size_t WGSIZE, bool UseHelper) {
57+
const auto it = sycl::ext::oneapi::this_work_item::get_nd_item<1>();
58+
size_t local_id = it.get_local_id();
59+
mem[local_id] = buf[local_id];
60+
group_barrier(it.get_group());
61+
if (it.get_group().leader()) {
62+
result = 0;
63+
if (!UseHelper) {
64+
for (int i = 0; i < WGSIZE; ++i) {
65+
result = result + mem[i];
66+
}
67+
} else {
68+
sum_helper(mem, result, WGSIZE);
69+
}
70+
assert(result == expected);
71+
}
72+
}
73+
74+
// Explicit instantiations for the relevant data types.
75+
#define SUM(T) \
76+
template void sum<T>( \
77+
sycl::ext::oneapi::experimental::work_group_memory<T[]> mem, T * buf, \
78+
sycl::ext::oneapi::experimental::work_group_memory<T> result, \
79+
T expected, size_t WGSIZE, bool UseHelper);
80+
81+
SUM(int)
82+
SUM(uint16_t)
83+
SUM(half)
84+
SUM(double)
85+
SUM(float)
86+
SUM(char)
87+
SUM(bool)
88+
89+
template <typename T>
90+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
91+
(ext::oneapi::experimental::nd_range_kernel<1>))
92+
void sum_marray(
93+
sycl::ext::oneapi::experimental::work_group_memory<sycl::marray<T, 16>> mem,
94+
T *buf, sycl::ext::oneapi::experimental::work_group_memory<T> result,
95+
T expected) {
96+
const auto it = sycl::ext::oneapi::this_work_item::get_nd_item<1>();
97+
size_t local_id = it.get_local_id();
98+
constexpr float tolerance = 0.01f;
99+
sycl::marray<T, 16> &data = mem;
100+
data[local_id] = buf[local_id];
101+
group_barrier(it.get_group());
102+
if (it.get_group().leader()) {
103+
result = 0;
104+
for (int i = 0; i < 16; ++i) {
105+
result = result + data[i];
106+
}
107+
assert((result - expected) * (result - expected) <= tolerance);
108+
}
109+
}
110+
111+
// Explicit instantiations for the relevant data types.
112+
#define SUM_MARRAY(T) \
113+
template void sum_marray<T>( \
114+
sycl::ext::oneapi::experimental::work_group_memory<sycl::marray<T, 16>> \
115+
mem, \
116+
T * buf, sycl::ext::oneapi::experimental::work_group_memory<T> result, \
117+
T expected);
118+
119+
SUM_MARRAY(float);
120+
SUM_MARRAY(double);
121+
SUM_MARRAY(half);
122+
123+
template <typename T>
124+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
125+
(ext::oneapi::experimental::nd_range_kernel<1>))
126+
void sum_vec(
127+
sycl::ext::oneapi::experimental::work_group_memory<sycl::vec<T, 16>> mem,
128+
T *buf, sycl::ext::oneapi::experimental::work_group_memory<T> result,
129+
T expected) {
130+
const auto it = sycl::ext::oneapi::this_work_item::get_nd_item<1>();
131+
size_t local_id = it.get_local_id();
132+
constexpr float tolerance = 0.01f;
133+
sycl::vec<T, 16> &data = mem;
134+
data[local_id] = buf[local_id];
135+
group_barrier(it.get_group());
136+
if (it.get_group().leader()) {
137+
result = 0;
138+
for (int i = 0; i < 16; ++i) {
139+
result = result + data[i];
140+
}
141+
assert((result - expected) * (result - expected) <= tolerance);
142+
}
143+
}
144+
145+
// Explicit instantiations for the relevant data types.
146+
#define SUM_VEC(T) \
147+
template void sum_vec<T>( \
148+
sycl::ext::oneapi::experimental::work_group_memory<sycl::vec<T, 16>> \
149+
mem, \
150+
T * buf, sycl::ext::oneapi::experimental::work_group_memory<T> result, \
151+
T expected);
152+
153+
SUM_VEC(float);
154+
SUM_VEC(double);
155+
SUM_VEC(half);
156+
157+
template <typename T, typename... Ts> void test_marray() {
158+
if (std::is_same_v<sycl::half, T> && !q.get_device().has(sycl::aspect::fp16))
159+
return;
160+
constexpr size_t WGSIZE = VEC_SIZE;
161+
T *buf = malloc_shared<T>(WGSIZE, q);
162+
assert(buf && "Shared USM allocation failed!");
163+
T expected = 0;
164+
for (int i = 0; i < WGSIZE; ++i) {
165+
buf[i] = ext::intel::math::sqrt(T(i));
166+
expected = expected + buf[i];
167+
}
168+
nd_range ndr{{SIZE}, {WGSIZE}};
169+
#ifndef __SYCL_DEVICE_ONLY__
170+
// Get the kernel object for the "mykernel" kernel.
171+
auto Bundle = get_kernel_bundle<sycl::bundle_state::executable>(ctx);
172+
kernel_id sum_id = ext::oneapi::experimental::get_kernel_id<sum_marray<T>>();
173+
kernel k_sum = Bundle.get_kernel(sum_id);
174+
q.submit([&](sycl::handler &cgh) {
175+
ext::oneapi::experimental::work_group_memory<marray<T, WGSIZE>> mem{cgh};
176+
ext::oneapi::experimental ::work_group_memory<T> result{cgh};
177+
cgh.set_args(mem, buf, result, expected);
178+
cgh.parallel_for(ndr, k_sum);
179+
}).wait();
180+
#endif // __SYCL_DEVICE_ONLY
181+
free(buf, q);
182+
if constexpr (sizeof...(Ts))
183+
test_marray<Ts...>();
184+
}
185+
186+
template <typename T, typename... Ts> void test_vec() {
187+
if (std::is_same_v<sycl::half, T> && !q.get_device().has(sycl::aspect::fp16))
188+
return;
189+
constexpr size_t WGSIZE = VEC_SIZE;
190+
T *buf = malloc_shared<T>(WGSIZE, q);
191+
assert(buf && "Shared USM allocation failed!");
192+
T expected = 0;
193+
for (int i = 0; i < WGSIZE; ++i) {
194+
buf[i] = ext::intel::math::sqrt(T(i));
195+
expected = expected + buf[i];
196+
}
197+
nd_range ndr{{SIZE}, {WGSIZE}};
198+
#ifndef __SYCL_DEVICE_ONLY__
199+
// Get the kernel object for the "mykernel" kernel.
200+
auto Bundle = get_kernel_bundle<sycl::bundle_state::executable>(ctx);
201+
kernel_id sum_id = ext::oneapi::experimental::get_kernel_id<sum_vec<T>>();
202+
kernel k_sum = Bundle.get_kernel(sum_id);
203+
q.submit([&](sycl::handler &cgh) {
204+
ext::oneapi::experimental::work_group_memory<vec<T, WGSIZE>> mem{cgh};
205+
ext::oneapi::experimental ::work_group_memory<T> result{cgh};
206+
cgh.set_args(mem, buf, result, expected);
207+
cgh.parallel_for(ndr, k_sum);
208+
}).wait();
209+
#endif // __SYCL_DEVICE_ONLY
210+
free(buf, q);
211+
if constexpr (sizeof...(Ts))
212+
test_vec<Ts...>();
213+
}
214+
215+
template <typename T, typename... Ts>
216+
void test(size_t SIZE, size_t WGSIZE, bool UseHelper) {
217+
if (std::is_same_v<sycl::half, T> && !q.get_device().has(sycl::aspect::fp16))
218+
return;
219+
T *buf = malloc_shared<T>(WGSIZE, q);
220+
assert(buf && "Shared USM allocation failed!");
221+
T expected = 0;
222+
for (int i = 0; i < WGSIZE; ++i) {
223+
buf[i] = T(i);
224+
expected = expected + buf[i];
225+
}
226+
nd_range ndr{{SIZE}, {WGSIZE}};
227+
// The following ifndef is required due to a number of limitations of free
228+
// function kernels. See CMPLRLLVM-61498.
229+
// TODO: Remove it once these limitations are no longer there.
230+
#ifndef __SYCL_DEVICE_ONLY__
231+
// Get the kernel object for the "mykernel" kernel.
232+
auto Bundle = get_kernel_bundle<sycl::bundle_state::executable>(ctx);
233+
kernel_id sum_id = ext::oneapi::experimental::get_kernel_id<sum<T>>();
234+
kernel k_sum = Bundle.get_kernel(sum_id);
235+
q.submit([&](sycl::handler &cgh) {
236+
ext::oneapi::experimental::work_group_memory<T[]> mem{WGSIZE, cgh};
237+
ext::oneapi::experimental ::work_group_memory<T> result{cgh};
238+
cgh.set_args(mem, buf, result, expected, WGSIZE, UseHelper);
239+
cgh.parallel_for(ndr, k_sum);
240+
}).wait();
241+
242+
#endif // __SYCL_DEVICE_ONLY
243+
free(buf, q);
244+
if constexpr (sizeof...(Ts))
245+
test<Ts...>(SIZE, WGSIZE, UseHelper);
246+
}
247+
248+
// Entry point: exercise the sum kernels across several type combinations,
// work-group sizes, and both the direct and helper-based accumulation paths.
int main() {
  test<int, uint16_t, half, double, float>(SIZE, SIZE, true /* UseHelper */);
  test<int, float, half>(SIZE, SIZE, false);
  test<int, double, char>(SIZE, SIZE / 2, false);
  test<int, bool, char>(SIZE, SIZE / 4, false);
  test_marray<float, double, half>();
  test_vec<float, double, half>();
  return 0;
}

0 commit comments

Comments
 (0)