Skip to content

Commit 987f44c

Browse files
committed
[OFFLOAD] Add support for indexed per-thread containers
Split from #158900 it adds a PerThreadContainer that can use STL-like indexed containers based on a slightly refactored PerThreadTable.
1 parent 2b135b9 commit 987f44c

File tree

2 files changed

+154
-9
lines changed

2 files changed

+154
-9
lines changed

offload/include/OpenMP/InteropAPI.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,17 +160,11 @@ struct InteropTableEntry {
160160
Interops.push_back(obj);
161161
}
162162

163-
template <class ClearFuncTy> void clear(ClearFuncTy f) {
164-
for (auto &Obj : Interops) {
165-
f(Obj);
166-
}
167-
}
168-
169163
/// vector interface
170164
int size() const { return Interops.size(); }
171165
iterator begin() { return Interops.begin(); }
172166
iterator end() { return Interops.end(); }
173-
iterator erase(iterator it) { return Interops.erase(it); }
167+
void clear() { Interops.clear(); }
174168
};
175169

176170
struct InteropTblTy

offload/include/PerThreadTable.h

Lines changed: 153 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,93 @@
1616
#include <list>
1717
#include <memory>
1818
#include <mutex>
19+
#include <type_traits>
20+
21+
template <typename ObjectType> struct PerThread {
22+
struct PerThreadData {
23+
std::unique_ptr<ObjectType> ThreadEntry;
24+
};
25+
26+
std::mutex Mutex;
27+
std::list<std::shared_ptr<PerThreadData>> ThreadDataList;
28+
29+
// define default constructors, disable copy and move constructors
30+
PerThread() = default;
31+
PerThread(const PerThread &) = delete;
32+
PerThread(PerThread &&) = delete;
33+
PerThread &operator=(const PerThread &) = delete;
34+
PerThread &operator=(PerThread &&) = delete;
35+
~PerThread() {
36+
std::lock_guard<std::mutex> Lock(Mutex);
37+
ThreadDataList.clear();
38+
}
39+
40+
private:
41+
PerThreadData &getThreadData() {
42+
static thread_local std::shared_ptr<PerThreadData> ThreadData = nullptr;
43+
if (!ThreadData) {
44+
ThreadData = std::make_shared<PerThreadData>();
45+
std::lock_guard<std::mutex> Lock(Mutex);
46+
ThreadDataList.push_back(ThreadData);
47+
}
48+
return *ThreadData;
49+
}
50+
51+
protected:
52+
ObjectType &getThreadEntry() {
53+
auto &ThData = getThreadData();
54+
if (ThData.ThEntry)
55+
return *ThData.ThEntry;
56+
ThData.ThEntry = std::make_unique<ObjectType>();
57+
return *ThData.ThEntry;
58+
}
59+
60+
public:
61+
ObjectType &get() { return getThreadEntry(); }
62+
63+
template <class F> void clear(F f) {
64+
std::lock_guard<std::mutex> Lock(Mutex);
65+
for (auto ThData : ThreadDataList) {
66+
if (!ThData->ThEntry)
67+
continue;
68+
f(*ThData->ThEntry);
69+
}
70+
ThreadDataList.clear();
71+
}
72+
};
1973

2074
// Using an STL container (such as std::vector) indexed by thread ID has
2175
// too many race conditions issues so we store each thread entry into a
2276
// thread_local variable.
2377
// T is the container type used to store the objects, e.g., std::vector,
2478
// std::set, etc. by each thread. O is the type of the stored objects e.g.,
2579
// omp_interop_val_t *, ...
26-
2780
template <typename ContainerType, typename ObjectType> struct PerThreadTable {
2881
using iterator = typename ContainerType::iterator;
2982

83+
template <typename, typename = std::void_t<>>
84+
struct has_iterator : std::false_type {};
85+
template <typename T>
86+
struct has_iterator<T, std::void_t<typename T::iterator>> : std::true_type {};
87+
88+
template <typename T, typename = std::void_t<>>
89+
struct has_clear : std::false_type {};
90+
template <typename T>
91+
struct has_clear<T, std::void_t<decltype(std::declval<T>().clear())>>
92+
: std::true_type {};
93+
94+
template <typename T, typename = std::void_t<>>
95+
struct has_clearAll : std::false_type {};
96+
template <typename T>
97+
struct has_clearAll<T, std::void_t<decltype(std::declval<T>().clearAll(1))>>
98+
: std::true_type {};
99+
100+
template <typename, typename = std::void_t<>>
101+
struct is_associative : std::false_type {};
102+
template <typename T>
103+
struct is_associative<T, std::void_t<typename T::mapped_type>>
104+
: std::true_type {};
105+
30106
struct PerThreadData {
31107
size_t NElements = 0;
32108
std::unique_ptr<ContainerType> ThEntry;
@@ -71,6 +147,11 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
71147
return ThData.NElements;
72148
}
73149

150+
void setNElements(size_t Size) {
151+
auto &NElements = getThreadNElements();
152+
NElements = Size;
153+
}
154+
74155
public:
75156
void add(ObjectType obj) {
76157
auto &Entry = getThreadEntry();
@@ -104,11 +185,81 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
104185
for (auto ThData : ThreadDataList) {
105186
if (!ThData->ThEntry || ThData->NElements == 0)
106187
continue;
107-
ThData->ThEntry->clear(f);
188+
if constexpr (has_clearAll<ContainerType>::value) {
189+
ThData->ThEntry->clearAll(f);
190+
} else if constexpr (has_iterator<ContainerType>::value &&
191+
has_clear<ContainerType>::value) {
192+
for (auto &Obj : *ThData->ThEntry) {
193+
if constexpr (is_associative<ContainerType>::value) {
194+
f(Obj.second);
195+
} else {
196+
f(Obj);
197+
}
198+
}
199+
ThData->ThEntry->clear();
200+
} else {
201+
static_assert(true, "Container type not supported");
202+
}
108203
ThData->NElements = 0;
109204
}
110205
ThreadDataList.clear();
111206
}
112207
};
113208

209+
template <typename T, typename = std::void_t<>> struct ContainerValueType {
210+
using type = typename T::value_type;
211+
};
212+
template <typename T>
213+
struct ContainerValueType<T, std::void_t<typename T::mapped_type>> {
214+
using type = typename T::mapped_type;
215+
};
216+
217+
template <typename ContainerType, size_t reserveSize = 0>
218+
struct PerThreadContainer
219+
: public PerThreadTable<ContainerType,
220+
typename ContainerValueType<ContainerType>::type> {
221+
222+
// helpers
223+
template <typename T, typename = std::void_t<>> struct indexType {
224+
using type = typename T::size_type;
225+
};
226+
template <typename T> struct indexType<T, std::void_t<typename T::key_type>> {
227+
using type = typename T::key_type;
228+
};
229+
template <typename T, typename = std::void_t<>>
230+
struct has_resize : std::false_type {};
231+
template <typename T>
232+
struct has_resize<T, std::void_t<decltype(std::declval<T>().resize(1))>>
233+
: std::true_type {};
234+
235+
template <typename T, typename = std::void_t<>>
236+
struct has_reserve : std::false_type {};
237+
template <typename T>
238+
struct has_reserve<T, std::void_t<decltype(std::declval<T>().reserve(1))>>
239+
: std::true_type {};
240+
241+
using IndexType = typename indexType<ContainerType>::type;
242+
using ObjectType = typename ContainerValueType<ContainerType>::type;
243+
244+
// Get the object for the given index in the current thread
245+
ObjectType &get(IndexType Index) {
246+
auto &Entry = this->getThreadEntry();
247+
248+
// specialized code for vector-like containers
249+
if constexpr (has_resize<ContainerType>::value) {
250+
if (Index >= Entry.size()) {
251+
if constexpr (has_reserve<ContainerType>::value && reserveSize > 0) {
252+
if (Entry.capacity() < reserveSize)
253+
Entry.reserve(reserveSize);
254+
}
255+
// If the index is out of bounds, try resize the container
256+
Entry.resize(Index + 1);
257+
}
258+
}
259+
ObjectType &Ret = Entry[Index];
260+
this->setNElements(Entry.size());
261+
return Ret;
262+
}
263+
};
264+
114265
#endif

0 commit comments

Comments
 (0)