Skip to content

Commit ce5fc20

Browse files
committed
[SYCL] Add prototype of sycl_khr_group_interface
Adds khr::work_group, khr::sub_group and khr::work_item classes, khr::get_item() and khr::leader_of() free-functions, and basic tests. Signed-off-by: John Pennycook <[email protected]>
1 parent 761d45d commit ce5fc20

File tree

4 files changed

+506
-0
lines changed

4 files changed

+506
-0
lines changed
Lines changed: 221 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,221 @@
1+
//==----- group_interface.hpp --- sycl_khr_group_interface extension -------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#pragma once
9+
10+
#include <sycl/ext/oneapi/free_function_queries.hpp>
11+
#include <sycl/id.hpp>
12+
#include <sycl/range.hpp>
13+
14+
#if __cplusplus >= 202302L && defined(__has_include)
15+
#if __has_include(<mdspan>)
16+
#include <mdspan>
17+
#endif
18+
#endif
19+
20+
namespace sycl {
21+
inline namespace _V1 {
22+
23+
namespace khr {
24+
25+
// Forward declarations for friend functions and traits.
template <int Dimensions> class work_group;
// sub_group is an ordinary (non-template) class, so it is declared without a
// template header.
class sub_group;
template <typename ParentGroup> class work_item;
// Declared before work_item so that the class can name it in a friend
// declaration.
template <typename ParentGroup> work_item<ParentGroup> get_item(ParentGroup);
30+
31+
} // namespace khr
32+
33+
namespace detail {
34+
#if defined(__cpp_lib_mdspan)
// Maps a dimensionality to a std::extents type whose extents are all the
// static value 1. Only the 1-, 2- and 3-dimensional cases are defined; any
// other Dimensions leaves the primary template incomplete.
template <typename IndexType, int Dimensions> struct single_extents;

template <typename IndexType> struct single_extents<IndexType, 1> {
  using type = std::extents<IndexType, 1>;
};

template <typename IndexType> struct single_extents<IndexType, 2> {
  using type = std::extents<IndexType, 1, 1>;
};

template <typename IndexType> struct single_extents<IndexType, 3> {
  using type = std::extents<IndexType, 1, 1, 1>;
};
#endif
49+
50+
template <typename T> struct is_khr_group : public std::false_type {};
51+
52+
template <int Dimensions>
53+
struct is_khr_group<khr::work_group<Dimensions>> : public std::true_type {};
54+
55+
struct is_khr_group<khr::sub_group> : public std::true_type {};
56+
57+
} // namespace detail
58+
59+
namespace khr {
60+
61+
template <int Dimensions = 1> class work_group {
62+
public:
63+
using id_type = id<Dimensions>;
64+
using linear_id_type = size_t;
65+
using range_type = range<Dimensions>;
66+
#if defined(__cpp_lib_mdspan)
67+
using extents_type = std::dextents<size_t, Dimensions>;
68+
#endif
69+
using size_type = size_t;
70+
static constexpr int dimensions = Dimensions;
71+
static constexpr memory_scope fence_scope = memory_scope::work_group;
72+
73+
work_group(group<Dimensions> g) noexcept {}
74+
75+
operator group<Dimensions>() const noexcept { return legacy(); }
76+
77+
id_type id() const noexcept { return legacy().get_group_id(); }
78+
79+
linear_id_type linear_id() const noexcept {
80+
return legacy().get_group_linear_id();
81+
}
82+
83+
range_type range() const noexcept { return legacy().get_group_range(); }
84+
85+
#if defined(__cpp_lib_mdspan)
86+
constexpr extents_type extents() const noexcept {
87+
auto LocalRange = legacy().get_local_range();
88+
if constexpr (dimensions == 1) {
89+
return extents_type(LocalRange[0]);
90+
} else if constexpr (dimensions == 2) {
91+
return extents_type(LocalRange[0], LocalRange[1]);
92+
} else if constexpr (dimensions == 3) {
93+
return extents_type(LocalRange[0], LocalRange[1], LocalRange[2]);
94+
}
95+
}
96+
97+
constexpr index_type extent(rank_type r) const noexcept {
98+
return extents().extent(r);
99+
}
100+
#endif
101+
102+
constexpr size_type size() const noexcept {
103+
return legacy().get_local_range().size();
104+
}
105+
106+
private:
107+
group<Dimensions> legacy() const noexcept {
108+
return ext::oneapi::this_work_item::get_work_group<Dimensions>();
109+
}
110+
};
111+
112+
class sub_group {
113+
public:
114+
using id_type = id<1>;
115+
using linear_id_type = uint32_t;
116+
using range_type = range<1>;
117+
#if defined(__cpp_lib_mdspan)
118+
using extents_type = std::dextents<uint32_t, 1>;
119+
#endif
120+
using size_type = uint32_t;
121+
static constexpr int dimensions = 1;
122+
static constexpr memory_scope fence_scope = memory_scope::sub_group;
123+
124+
sub_group(sycl::sub_group g) noexcept {}
125+
126+
operator sycl::sub_group() const noexcept { return legacy(); }
127+
128+
id_type id() const noexcept { return legacy().get_group_id(); }
129+
130+
linear_id_type linear_id() const noexcept {
131+
return legacy().get_group_linear_id();
132+
}
133+
134+
range_type range() const noexcept { return legacy().get_group_range(); }
135+
136+
#if defined(__cpp_lib_mdspan)
137+
constexpr extents_type extents() const noexcept {
138+
return extents_type(legacy().get_local_range()[0]);
139+
}
140+
141+
constexpr index_type extent(rank_type r) const noexcept {
142+
return extents().extent(r);
143+
}
144+
#endif
145+
146+
constexpr size_type size() const noexcept {
147+
return legacy().get_local_range()[0];
148+
}
149+
150+
constexpr size_type max_size() const noexcept {
151+
return legacy().get_max_local_range()[0];
152+
}
153+
154+
private:
155+
sycl::sub_group legacy() const noexcept {
156+
return ext::oneapi::this_work_item::get_sub_group();
157+
}
158+
};
159+
160+
template <typename ParentGroup> class work_item {
161+
public:
162+
using id_type = typename ParentGroup::id_type;
163+
using linear_id_type = typename ParentGroup::linear_id_type;
164+
using range_type = typename ParentGroup::range_type;
165+
#if defined(__cpp_lib_mdspan)
166+
using extents_type =
167+
detail::single_extents<typename ParentGroup::extents_type::index_type,
168+
ParentGroup::dimensions>;
169+
#endif
170+
using size_type = typename ParentGroup::size_type;
171+
static constexpr int dimensions = ParentGroup::dimensions;
172+
static constexpr memory_scope fence_scope = memory_scope::work_item;
173+
174+
id_type id() const noexcept { return legacy().get_local_id(); }
175+
176+
linear_id_type linear_id() const noexcept {
177+
return legacy().get_local_linear_id();
178+
}
179+
180+
range_type range() const noexcept { return legacy().get_local_range(); }
181+
182+
#if defined(__cpp_lib_mdspan)
183+
constexpr extents_type extents() const noexcept { return extents_type(); }
184+
185+
constexpr index_type extent(rank_type r) const noexcept {
186+
return extents().extent(r);
187+
}
188+
#endif
189+
190+
constexpr size_type size() const noexcept { return 1; }
191+
192+
private:
193+
auto legacy() const noexcept {
194+
if constexpr (std::is_same_v<ParentGroup, sub_group>) {
195+
return ext::oneapi::this_work_item::get_sub_group();
196+
} else {
197+
return ext::oneapi::this_work_item::get_work_group<
198+
ParentGroup::dimensions>();
199+
}
200+
}
201+
202+
protected:
203+
work_item() {}
204+
205+
friend work_item<ParentGroup> get_item<ParentGroup>(ParentGroup);
206+
};
207+
208+
template <typename ParentGroup>
209+
std::enable_if_t<detail::is_khr_group<ParentGroup>::value,
210+
work_item<ParentGroup>>
211+
get_item(ParentGroup g) {
212+
return work_item<ParentGroup>{};
213+
}
214+
215+
/// Returns true when the calling work-item's linear id within the group g,
/// as reported by get_item(g), is 0.
template <typename Group> bool leader_of(Group g) {
  const auto Item = get_item(g);
  const bool IsLeader = (Item.linear_id() == 0);
  return IsLeader;
}
218+
219+
} // namespace khr
220+
} // namespace _V1
221+
} // namespace sycl
Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,39 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <cassert>
5+
#include <iostream>
6+
#include <sycl/detail/core.hpp>
7+
#include <sycl/group_algorithm.hpp>
8+
#include <sycl/khr/group_interface.hpp>
9+
10+
using namespace sycl;
11+
12+
// Launches one 4x4 work-group in which every work-item that
// khr::leader_of() reports as leader increments a counter, then checks that
// the counter ended up at exactly 1.
void test(queue q) {
  int out = 0;
  size_t side = 4;

  range<2> wg_range(side, side);
  {
    // The buffer writes back into `out` when it goes out of scope.
    buffer<int> out_buf(&out, 1);

    q.submit([&](handler &cgh) {
      auto counter = out_buf.template get_access<access::mode::read_write>(cgh);
      // Global range equals local range, so there is a single work-group.
      cgh.parallel_for(nd_range<2>(wg_range, wg_range), [=](nd_item<2> item) {
        khr::work_group<2> wg = item.get_group();
        const bool is_leader = khr::leader_of(wg);
        if (is_leader) {
          counter[0] += 1;
        }
      });
    });
  }
  assert(out == 1);
}
32+
33+
int main() {
34+
queue q;
35+
test(q);
36+
37+
std::cout << "Test passed." << std::endl;
38+
return 0;
39+
}
Lines changed: 110 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,110 @@
1+
// REQUIRES: cpu
2+
3+
// RUN: %{build} %cxx_std_optionc++23 -o %t.out
4+
// RUN: %{run} %t.out
5+
6+
#include <sycl/detail/core.hpp>
7+
#include <sycl/khr/group_interface.hpp>
8+
9+
#include <sycl/builtins.hpp>
10+
11+
#include <type_traits>
12+
13+
using namespace sycl;
14+
15+
static_assert(std::is_same_v<khr::sub_group::id_type, id<1>>);
16+
static_assert(std::is_same_v<khr::sub_group::linear_id_type, uint32_t>);
17+
static_assert(std::is_same_v<khr::sub_group::range_type, range<1>>);
18+
#if defined(__cpp_lib_mdspan)
19+
static_assert(
20+
std::is_same_v<khr::sub_group::extents_type, std::dextents<uint32_t, 1>>);
21+
#endif
22+
static_assert(std::is_same_v<khr::sub_group::size_type, uint32_t>);
23+
static_assert(khr::sub_group::dimensions == 1);
24+
static_assert(khr::sub_group::fence_scope == memory_scope::sub_group);
25+
26+
int main() {
27+
queue q(cpu_selector_v);
28+
29+
const int sz = 16;
30+
q.submit([&](handler &h) {
31+
h.parallel_for(nd_range<1>{sz, sz}, [=](nd_item<1> item) {
32+
sub_group g = item.get_sub_group();
33+
34+
khr::sub_group sg = g;
35+
assert(sg.id() == g.get_group_id());
36+
assert(sg.linear_id() == g.get_group_linear_id());
37+
assert(sg.range() == g.get_group_range());
38+
#if defined(__cpp_lib_mdspan)
39+
assert(sg.extents().rank() == 1);
40+
assert(sg.extent(0) == g.get_local_range()[0]);
41+
#endif
42+
assert(sg.size() == g.get_local_linear_range());
43+
assert(sg.max_size() == g.get_max_local_range()[0]);
44+
45+
khr::work_item wi = get_item(sg);
46+
assert(wi.id() == g.get_local_id());
47+
assert(wi.linear_id() == g.get_local_linear_id());
48+
assert(wi.range() == g.get_local_range());
49+
#if defined(__cpp_lib_mdspan)
50+
assert(wi.extents().rank() == 1);
51+
assert(wi.extent(0) == 1);
52+
#endif
53+
assert(wi.size() == 1);
54+
});
55+
});
56+
q.submit([&](handler &h) {
57+
h.parallel_for(nd_range<2>{range<2>{sz, sz}, range<2>{sz, sz}},
58+
[=](nd_item<2> item) {
59+
sub_group g = item.get_sub_group();
60+
61+
khr::sub_group sg = g;
62+
assert(sg.id() == g.get_group_id());
63+
assert(sg.linear_id() == g.get_group_linear_id());
64+
assert(sg.range() == g.get_group_range());
65+
#if defined(__cpp_lib_mdspan)
66+
assert(sg.extents().rank() == 1);
67+
assert(sg.extent(0) == g.get_local_range()[0]);
68+
#endif
69+
assert(sg.size() == g.get_local_linear_range());
70+
assert(sg.max_size() == g.get_max_local_range()[0]);
71+
72+
khr::work_item wi = get_item(sg);
73+
assert(wi.id() == g.get_local_id());
74+
assert(wi.linear_id() == g.get_local_linear_id());
75+
assert(wi.range() == g.get_local_range());
76+
#if defined(__cpp_lib_mdspan)
77+
assert(wi.extents().rank() == 1);
78+
assert(wi.extent(0) == 1);
79+
#endif
80+
assert(wi.size() == 1);
81+
});
82+
});
83+
q.submit([&](handler &h) {
84+
h.parallel_for(nd_range<3>{range<3>{sz, sz, sz}, range<3>{sz, sz, sz}},
85+
[=](nd_item<3> item) {
86+
sub_group g = item.get_sub_group();
87+
88+
khr::sub_group sg = g;
89+
assert(sg.id() == g.get_group_id());
90+
assert(sg.linear_id() == g.get_group_linear_id());
91+
assert(sg.range() == g.get_group_range());
92+
#if defined(__cpp_lib_mdspan)
93+
assert(sg.extents().rank() == 1);
94+
assert(sg.extent(0) == g.get_local_range()[0]);
95+
#endif
96+
assert(sg.size() == g.get_local_linear_range());
97+
assert(sg.max_size() == g.get_max_local_range()[0]);
98+
99+
khr::work_item wi = get_item(sg);
100+
assert(wi.id() == g.get_local_id());
101+
assert(wi.linear_id() == g.get_local_linear_id());
102+
assert(wi.range() == g.get_local_range());
103+
#if defined(__cpp_lib_mdspan)
104+
assert(wi.extents().rank() == 1);
105+
assert(wi.extent(0) == 1);
106+
#endif
107+
});
108+
});
109+
q.wait();
110+
}

0 commit comments

Comments
 (0)