Skip to content

Commit cc9f060

Browse files
committed
Use accessor instead of device-side assert
Signed-off-by: Michael Aziz <[email protected]>
1 parent e0f4e93 commit cc9f060

File tree

1 file changed

+51
-45
lines changed

1 file changed

+51
-45
lines changed
Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
// REQUIRES: cpu
2-
31
// RUN: %{build} %cxx_std_optionc++23 -o %t.out
42
// RUN: %{run} %t.out
53

@@ -46,91 +44,99 @@ static_assert(khr::work_group<3>::dimensions == 3);
4644
static_assert(khr::work_group<3>::fence_scope == memory_scope::work_group);
4745

4846
int main() {
49-
queue q(cpu_selector_v);
47+
queue q;
48+
sycl::buffer<bool, 1> result(1);
5049

5150
const int sz = 16;
5251
q.submit([&](handler &h) {
52+
sycl::accessor<bool, 0> acc{result, h};
5353
h.parallel_for(nd_range<1>{sz, sz}, [=](nd_item<1> item) {
5454
group<1> g = item.get_group();
5555

5656
khr::work_group<1> wg = g;
57-
assert(wg.id() == g.get_group_id());
58-
assert(wg.linear_id() == g.get_group_linear_id());
59-
assert(wg.range() == g.get_group_range());
57+
acc = true;
58+
acc &= (wg.id() == g.get_group_id());
59+
acc &= (wg.linear_id() == g.get_group_linear_id());
60+
acc &= (wg.range() == g.get_group_range());
6061
#if defined(__cpp_lib_mdspan)
61-
assert(wg.extents().rank() == 1);
62-
assert(wg.extent(0) == g.get_local_range()[0]);
62+
acc &= (wg.extents().rank() == 1);
63+
acc &= (wg.extent(0) == g.get_local_range()[0]);
6364
#endif
64-
assert(wg.size() == g.get_local_linear_range());
65+
acc &= (wg.size() == g.get_local_linear_range());
6566

6667
khr::member_item wi = get_member_item(wg);
67-
assert(wi.id() == g.get_local_id());
68-
assert(wi.linear_id() == g.get_local_linear_id());
69-
assert(wi.range() == g.get_local_range());
68+
acc &= (wi.id() == g.get_local_id());
69+
acc &= (wi.linear_id() == g.get_local_linear_id());
70+
acc &= (wi.range() == g.get_local_range());
7071
#if defined(__cpp_lib_mdspan)
71-
assert(wi.extents().rank() == 1);
72-
assert(wi.extent(0) == 1);
72+
acc &= (wi.extents().rank() == 1);
73+
acc &= (wi.extent(0) == 1);
7374
#endif
74-
assert(wi.size() == 1);
75+
acc &= (wi.size() == 1);
7576
});
7677
});
7778
q.submit([&](handler &h) {
79+
sycl::accessor<bool, 0> acc{result, h};
7880
h.parallel_for(nd_range<2>{range<2>{sz, sz}, range<2>{sz, sz}},
7981
[=](nd_item<2> item) {
8082
group<2> g = item.get_group();
8183

8284
khr::work_group<2> wg = g;
83-
assert(wg.id() == g.get_group_id());
84-
assert(wg.linear_id() == g.get_group_linear_id());
85-
assert(wg.range() == g.get_group_range());
85+
acc &= (wg.id() == g.get_group_id());
86+
acc &= (wg.linear_id() == g.get_group_linear_id());
87+
acc &= (wg.range() == g.get_group_range());
8688
#if defined(__cpp_lib_mdspan)
87-
assert(wg.extents().rank() == 2);
88-
assert(wg.extent(0) == g.get_local_range()[0]);
89-
assert(wg.extent(1) == g.get_local_range()[1]);
89+
acc &= (wg.extents().rank() == 2);
90+
acc &= (wg.extent(0) == g.get_local_range()[0]);
91+
acc &= (wg.extent(1) == g.get_local_range()[1]);
9092
#endif
91-
assert(wg.size() == g.get_local_linear_range());
93+
acc &= (wg.size() == g.get_local_linear_range());
9294

9395
khr::member_item wi = get_member_item(wg);
94-
assert(wi.id() == g.get_local_id());
95-
assert(wi.linear_id() == g.get_local_linear_id());
96-
assert(wi.range() == g.get_local_range());
96+
acc &= (wi.id() == g.get_local_id());
97+
acc &= (wi.linear_id() == g.get_local_linear_id());
98+
acc &= (wi.range() == g.get_local_range());
9799
#if defined(__cpp_lib_mdspan)
98-
assert(wi.extents().rank() == 2);
99-
assert(wi.extent(0) == 1);
100-
assert(wi.extent(1) == 1);
100+
acc &= (wi.extents().rank() == 2);
101+
acc &= (wi.extent(0) == 1);
102+
acc &= (wi.extent(1) == 1);
101103
#endif
102-
assert(wi.size() == 1);
104+
acc &= (wi.size() == 1);
103105
});
104106
});
105107
q.submit([&](handler &h) {
108+
sycl::accessor<bool, 0> acc{result, h};
106109
h.parallel_for(nd_range<3>{range<3>{sz, sz, sz}, range<3>{sz, sz, sz}},
107110
[=](nd_item<3> item) {
108111
group<3> g = item.get_group();
109112

110113
khr::work_group<3> wg = g;
111-
assert(wg.id() == g.get_group_id());
112-
assert(wg.linear_id() == g.get_group_linear_id());
113-
assert(wg.range() == g.get_group_range());
114+
acc &= (wg.id() == g.get_group_id());
115+
acc &= (wg.linear_id() == g.get_group_linear_id());
116+
acc &= (wg.range() == g.get_group_range());
114117
#if defined(__cpp_lib_mdspan)
115-
assert(wg.extents().rank() == 3);
116-
assert(wg.extent(0) == g.get_local_range()[0]);
117-
assert(wg.extent(1) == g.get_local_range()[1]);
118-
assert(wg.extent(2) == g.get_local_range()[2]);
118+
acc &= (wg.extents().rank() == 3);
119+
acc &= (wg.extent(0) == g.get_local_range()[0]);
120+
acc &= (wg.extent(1) == g.get_local_range()[1]);
121+
acc &= (wg.extent(2) == g.get_local_range()[2]);
119122
#endif
120-
assert(wg.size() == g.get_local_linear_range());
123+
acc &= (wg.size() == g.get_local_linear_range());
121124

122125
khr::member_item wi = get_member_item(wg);
123-
assert(wi.id() == g.get_local_id());
124-
assert(wi.linear_id() == g.get_local_linear_id());
125-
assert(wi.range() == g.get_local_range());
126+
acc &= (wi.id() == g.get_local_id());
127+
acc &= (wi.linear_id() == g.get_local_linear_id());
128+
acc &= (wi.range() == g.get_local_range());
126129
#if defined(__cpp_lib_mdspan)
127-
assert(wi.extents().rank() == 3);
128-
assert(wi.extent(0) == 1);
129-
assert(wi.extent(1) == 1);
130-
assert(wi.extent(2) == 1);
130+
acc &= (wi.extents().rank() == 3);
131+
acc &= (wi.extent(0) == 1);
132+
acc &= (wi.extent(1) == 1);
133+
acc &= (wi.extent(2) == 1);
131134
#endif
132-
assert(wi.size() == 1);
135+
acc &= (wi.size() == 1);
133136
});
134137
});
135138
q.wait();
139+
140+
sycl::host_accessor<bool, 0> acc{result};
141+
assert(static_cast<bool>(acc));
136142
}

0 commit comments

Comments
 (0)