Skip to content

Commit 249ea13

Browse files
committed
[SYCL] --
Signed-off-by: Hu, Peisen <[email protected]>
1 parent 6cb0404 commit 249ea13

File tree

1 file changed

+34
-29
lines changed

1 file changed

+34
-29
lines changed

sycl/test-e2e/GroupAlgorithm/root_group.cpp

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,34 @@ void testQueriesAndProperties() {
6060
check_max_num_work_group_sync(maxWGsWithLimits);
6161
}
6262

63+
template <typename T> struct TestKernel1 {
64+
T &m_data;
65+
TestKernel1(T &data_) : m_data(data_) {}
66+
void operator()(sycl::nd_item<1> it) const {
67+
volatile float X = 1.0f;
68+
volatile float Y = 1.0f;
69+
auto root = it.ext_oneapi_get_root_group();
70+
m_data[root.get_local_id()] = root.get_local_id();
71+
sycl::group_barrier(root);
72+
// Delay half of the workgroups with extra work to check that the barrier
73+
// synchronizes the whole device.
74+
if (it.get_group(0) % 2 == 0) {
75+
X += sycl::sin(X);
76+
Y += sycl::cos(Y);
77+
}
78+
root = sycl::ext::oneapi::experimental::this_work_item::get_root_group<1>();
79+
int sum = m_data[root.get_local_id()] +
80+
m_data[root.get_local_range() - root.get_local_id() - 1];
81+
sycl::group_barrier(root);
82+
m_data[root.get_local_id()] = sum;
83+
}
84+
auto get(sycl::ext::oneapi::experimental::properties_tag) {
85+
return sycl::ext::oneapi::experimental::properties{
86+
sycl::ext::oneapi::experimental::use_root_sync};
87+
;
88+
}
89+
};
90+
6391
void testRootGroup() {
6492
sycl::queue q;
6593
const auto bundle =
@@ -70,32 +98,11 @@ void testRootGroup() {
7098
.ext_oneapi_get_info<sycl::ext::oneapi::experimental::info::
7199
kernel_queue_specific::max_num_work_groups>(
72100
q, WorkGroupSize, 0);
73-
const auto props = sycl::ext::oneapi::experimental::properties{
74-
sycl::ext::oneapi::experimental::use_root_sync};
75101
sycl::buffer<int> dataBuf{sycl::range{maxWGs * WorkGroupSize}};
76102
const auto range = sycl::nd_range<1>{maxWGs * WorkGroupSize, WorkGroupSize};
77103
q.submit([&](sycl::handler &h) {
78104
sycl::accessor data{dataBuf, h};
79-
h.parallel_for<
80-
class RootGroupKernel>(range, props, [=](sycl::nd_item<1> it) {
81-
volatile float X = 1.0f;
82-
volatile float Y = 1.0f;
83-
auto root = it.ext_oneapi_get_root_group();
84-
data[root.get_local_id()] = root.get_local_id();
85-
sycl::group_barrier(root);
86-
// Delay half of the workgroups with extra work to check that the barrier
87-
// synchronizes the whole device.
88-
if (it.get_group(0) % 2 == 0) {
89-
X += sycl::sin(X);
90-
Y += sycl::cos(Y);
91-
}
92-
root =
93-
sycl::ext::oneapi::experimental::this_work_item::get_root_group<1>();
94-
int sum = data[root.get_local_id()] +
95-
data[root.get_local_range() - root.get_local_id() - 1];
96-
sycl::group_barrier(root);
97-
data[root.get_local_id()] = sum;
98-
});
105+
h.parallel_for<class RootGroupKernel>(range, TestKernel1(data));
99106
});
100107
sycl::host_accessor data{dataBuf};
101108
const int workItemCount = static_cast<int>(range.get_global_range().size());
@@ -104,9 +111,9 @@ void testRootGroup() {
104111
}
105112
}
106113

107-
template <typename T> class RootGroupFunctionsKernel {
108-
public:
109-
RootGroupFunctionsKernel(T &testResults_) : m_testResults(testResults_) {}
114+
template <typename T> struct TestKernel2 {
115+
T m_testResults;
116+
TestKernel2(T &testResults_) : m_testResults(testResults_) {}
110117
void operator()(sycl::nd_item<1> it) const {
111118
const auto root = it.ext_oneapi_get_root_group();
112119
if (root.leader() || root.get_local_id() == 3) {
@@ -128,9 +135,6 @@ template <typename T> class RootGroupFunctionsKernel {
128135
return sycl::ext::oneapi::experimental::properties{
129136
sycl::ext::oneapi::experimental::use_root_sync};
130137
}
131-
132-
private:
133-
T m_testResults;
134138
};
135139

136140
void testRootGroupFunctions() {
@@ -149,7 +153,8 @@ void testRootGroupFunctions() {
149153
const auto range = sycl::nd_range<1>{maxWGs * WorkGroupSize, WorkGroupSize};
150154
q.submit([&](sycl::handler &h) {
151155
sycl::accessor testResults{testResultsBuf, h};
152-
h.parallel_for(range, RootGroupFunctionsKernel(testResults));
156+
h.parallel_for<class RootGroupFunctionsKernel>(range,
157+
TestKernel2(testResults));
153158
});
154159
sycl::host_accessor testResults{testResultsBuf};
155160
for (int i = 0; i < testCount; i++) {

0 commit comments

Comments
 (0)