Skip to content

Commit 3f22ed1

Browse files
adurangjhuber6
andauthored
[OFFLOAD] Add support for indexed per-thread containers (llvm#164263)
Split from llvm#158900 it adds a PerThreadContainer that can use STL-like indexed containers based on a slightly refactored PerThreadTable. --------- Co-authored-by: Joseph Huber <[email protected]>
1 parent bd04ef6 commit 3f22ed1

File tree

2 files changed

+208
-59
lines changed

2 files changed

+208
-59
lines changed

offload/include/OpenMP/InteropAPI.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,17 +160,11 @@ struct InteropTableEntry {
160160
Interops.push_back(obj);
161161
}
162162

163-
template <class ClearFuncTy> void clear(ClearFuncTy f) {
164-
for (auto &Obj : Interops) {
165-
f(Obj);
166-
}
167-
}
168-
169163
/// vector interface
170164
int size() const { return Interops.size(); }
171165
iterator begin() { return Interops.begin(); }
172166
iterator end() { return Interops.end(); }
173-
iterator erase(iterator it) { return Interops.erase(it); }
167+
void clear() { Interops.clear(); }
174168
};
175169

176170
struct InteropTblTy

offload/include/PerThreadTable.h

Lines changed: 207 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,101 +14,256 @@
1414
#define OFFLOAD_PERTHREADTABLE_H
1515

1616
#include <list>
17+
#include <llvm/ADT/SmallVector.h>
18+
#include <llvm/Support/Error.h>
1719
#include <memory>
1820
#include <mutex>
21+
#include <type_traits>
22+
23+
template <typename ObjectType> class PerThread {
24+
std::mutex Mutex;
25+
llvm::SmallVector<std::shared_ptr<ObjectType>> ThreadDataList;
26+
27+
ObjectType &getThreadData() {
28+
static thread_local std::shared_ptr<ObjectType> ThreadData = nullptr;
29+
if (!ThreadData) {
30+
ThreadData = std::make_shared<ObjectType>();
31+
std::lock_guard<std::mutex> Lock(Mutex);
32+
ThreadDataList.push_back(ThreadData);
33+
}
34+
return *ThreadData;
35+
}
36+
37+
public:
38+
// Define default constructors, disable copy and move constructors.
39+
PerThread() = default;
40+
PerThread(const PerThread &) = delete;
41+
PerThread(PerThread &&) = delete;
42+
PerThread &operator=(const PerThread &) = delete;
43+
PerThread &operator=(PerThread &&) = delete;
44+
~PerThread() {
45+
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
46+
"Cannot be deleted while other threads are adding entries");
47+
ThreadDataList.clear();
48+
}
49+
50+
ObjectType &get() { return getThreadData(); }
51+
52+
template <class ClearFuncTy> void clear(ClearFuncTy ClearFunc) {
53+
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
54+
"Clear cannot be called while other threads are adding entries");
55+
for (std::shared_ptr<ObjectType> ThreadData : ThreadDataList) {
56+
if (!ThreadData)
57+
continue;
58+
ClearFunc(*ThreadData);
59+
}
60+
ThreadDataList.clear();
61+
}
62+
};
63+
64+
template <typename ContainerTy> struct ContainerConcepts {
65+
template <typename, template <typename> class, typename = std::void_t<>>
66+
struct has : std::false_type {};
67+
template <typename Ty, template <typename> class Op>
68+
struct has<Ty, Op, std::void_t<Op<Ty>>> : std::true_type {};
69+
70+
template <typename Ty> using IteratorTypeCheck = typename Ty::iterator;
71+
template <typename Ty> using MappedTypeCheck = typename Ty::mapped_type;
72+
template <typename Ty> using ValueTypeCheck = typename Ty::value_type;
73+
template <typename Ty> using KeyTypeCheck = typename Ty::key_type;
74+
template <typename Ty> using SizeTypeCheck = typename Ty::size_type;
75+
76+
template <typename Ty>
77+
using ClearCheck = decltype(std::declval<Ty>().clear());
78+
template <typename Ty>
79+
using ReserveCheck = decltype(std::declval<Ty>().reserve(1));
80+
template <typename Ty>
81+
using ResizeCheck = decltype(std::declval<Ty>().resize(1));
82+
83+
static constexpr bool hasIterator =
84+
has<ContainerTy, IteratorTypeCheck>::value;
85+
static constexpr bool hasClear = has<ContainerTy, ClearCheck>::value;
86+
static constexpr bool isAssociative =
87+
has<ContainerTy, MappedTypeCheck>::value;
88+
static constexpr bool hasReserve = has<ContainerTy, ReserveCheck>::value;
89+
static constexpr bool hasResize = has<ContainerTy, ResizeCheck>::value;
90+
91+
template <typename, template <typename> class, typename = std::void_t<>>
92+
struct has_type {
93+
using type = void;
94+
};
95+
template <typename Ty, template <typename> class Op>
96+
struct has_type<Ty, Op, std::void_t<Op<Ty>>> {
97+
using type = Op<Ty>;
98+
};
99+
100+
using iterator = typename has_type<ContainerTy, IteratorTypeCheck>::type;
101+
using value_type = typename std::conditional_t<
102+
isAssociative, typename has_type<ContainerTy, MappedTypeCheck>::type,
103+
typename has_type<ContainerTy, ValueTypeCheck>::type>;
104+
using key_type = typename std::conditional_t<
105+
isAssociative, typename has_type<ContainerTy, KeyTypeCheck>::type,
106+
typename has_type<ContainerTy, SizeTypeCheck>::type>;
107+
};
19108

20109
// Using an STL container (such as std::vector) indexed by thread ID has
21110
// too many race conditions issues so we store each thread entry into a
22111
// thread_local variable.
23-
// T is the container type used to store the objects, e.g., std::vector,
24-
// std::set, etc. by each thread. O is the type of the stored objects e.g.,
25-
// omp_interop_val_t *, ...
26-
27-
template <typename ContainerType, typename ObjectType> struct PerThreadTable {
28-
using iterator = typename ContainerType::iterator;
112+
// ContainerType is the container type used to store the objects, e.g.,
113+
// std::vector, std::set, etc. by each thread. ObjectType is the type of the
114+
// stored objects e.g., omp_interop_val_t *, ...
115+
template <typename ContainerType, typename ObjectType> class PerThreadTable {
116+
using iterator = typename ContainerConcepts<ContainerType>::iterator;
29117

30118
struct PerThreadData {
31-
size_t NElements = 0;
32-
std::unique_ptr<ContainerType> ThEntry;
119+
size_t Size = 0;
120+
std::unique_ptr<ContainerType> ThreadEntry;
33121
};
34122

35-
std::mutex Mtx;
36-
std::list<std::shared_ptr<PerThreadData>> ThreadDataList;
37-
38-
// define default constructors, disable copy and move constructors
39-
PerThreadTable() = default;
40-
PerThreadTable(const PerThreadTable &) = delete;
41-
PerThreadTable(PerThreadTable &&) = delete;
42-
PerThreadTable &operator=(const PerThreadTable &) = delete;
43-
PerThreadTable &operator=(PerThreadTable &&) = delete;
44-
~PerThreadTable() {
45-
std::lock_guard<std::mutex> Lock(Mtx);
46-
ThreadDataList.clear();
47-
}
123+
std::mutex Mutex;
124+
llvm::SmallVector<std::shared_ptr<PerThreadData>> ThreadDataList;
48125

49-
private:
50126
PerThreadData &getThreadData() {
51-
static thread_local std::shared_ptr<PerThreadData> ThData = nullptr;
52-
if (!ThData) {
53-
ThData = std::make_shared<PerThreadData>();
54-
std::lock_guard<std::mutex> Lock(Mtx);
55-
ThreadDataList.push_back(ThData);
127+
static thread_local std::shared_ptr<PerThreadData> ThreadData = nullptr;
128+
if (!ThreadData) {
129+
ThreadData = std::make_shared<PerThreadData>();
130+
std::lock_guard<std::mutex> Lock(Mutex);
131+
ThreadDataList.push_back(ThreadData);
56132
}
57-
return *ThData;
133+
return *ThreadData;
58134
}
59135

60136
protected:
61137
ContainerType &getThreadEntry() {
62-
auto &ThData = getThreadData();
63-
if (ThData.ThEntry)
64-
return *ThData.ThEntry;
65-
ThData.ThEntry = std::make_unique<ContainerType>();
66-
return *ThData.ThEntry;
138+
PerThreadData &ThreadData = getThreadData();
139+
if (ThreadData.ThreadEntry)
140+
return *ThreadData.ThreadEntry;
141+
ThreadData.ThreadEntry = std::make_unique<ContainerType>();
142+
return *ThreadData.ThreadEntry;
143+
}
144+
145+
size_t &getThreadSize() {
146+
PerThreadData &ThreadData = getThreadData();
147+
return ThreadData.Size;
67148
}
68149

69-
size_t &getThreadNElements() {
70-
auto &ThData = getThreadData();
71-
return ThData.NElements;
150+
void setSize(size_t Size) {
151+
size_t &SizeRef = getThreadSize();
152+
SizeRef = Size;
72153
}
73154

74155
public:
156+
// define default constructors, disable copy and move constructors.
157+
PerThreadTable() = default;
158+
PerThreadTable(const PerThreadTable &) = delete;
159+
PerThreadTable(PerThreadTable &&) = delete;
160+
PerThreadTable &operator=(const PerThreadTable &) = delete;
161+
PerThreadTable &operator=(PerThreadTable &&) = delete;
162+
~PerThreadTable() {
163+
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
164+
"Cannot be deleted while other threads are adding entries");
165+
ThreadDataList.clear();
166+
}
167+
75168
void add(ObjectType obj) {
76-
auto &Entry = getThreadEntry();
77-
auto &NElements = getThreadNElements();
78-
NElements++;
169+
ContainerType &Entry = getThreadEntry();
170+
size_t &SizeRef = getThreadSize();
171+
SizeRef++;
79172
Entry.add(obj);
80173
}
81174

82175
iterator erase(iterator it) {
83-
auto &Entry = getThreadEntry();
84-
auto &NElements = getThreadNElements();
85-
NElements--;
176+
ContainerType &Entry = getThreadEntry();
177+
size_t &SizeRef = getThreadSize();
178+
SizeRef--;
86179
return Entry.erase(it);
87180
}
88181

89-
size_t size() { return getThreadNElements(); }
182+
size_t size() { return getThreadSize(); }
90183

91184
// Iterators to traverse objects owned by
92-
// the current thread
185+
// the current thread.
93186
iterator begin() {
94-
auto &Entry = getThreadEntry();
187+
ContainerType &Entry = getThreadEntry();
95188
return Entry.begin();
96189
}
97190
iterator end() {
98-
auto &Entry = getThreadEntry();
191+
ContainerType &Entry = getThreadEntry();
99192
return Entry.end();
100193
}
101194

102-
template <class F> void clear(F f) {
103-
std::lock_guard<std::mutex> Lock(Mtx);
104-
for (auto ThData : ThreadDataList) {
105-
if (!ThData->ThEntry || ThData->NElements == 0)
195+
template <class ClearFuncTy> void clear(ClearFuncTy ClearFunc) {
196+
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
197+
"Clear cannot be called while other threads are adding entries");
198+
for (std::shared_ptr<PerThreadData> ThreadData : ThreadDataList) {
199+
if (!ThreadData->ThreadEntry || ThreadData->Size == 0)
106200
continue;
107-
ThData->ThEntry->clear(f);
108-
ThData->NElements = 0;
201+
if constexpr (ContainerConcepts<ContainerType>::hasIterator &&
202+
ContainerConcepts<ContainerType>::hasClear) {
203+
for (auto &Obj : *ThreadData->ThreadEntry) {
204+
if constexpr (ContainerConcepts<ContainerType>::isAssociative) {
205+
ClearFunc(Obj.second);
206+
} else {
207+
ClearFunc(Obj);
208+
}
209+
}
210+
ThreadData->ThreadEntry->clear();
211+
} else {
212+
static_assert(true, "Container type not supported");
213+
}
214+
ThreadData->Size = 0;
109215
}
110216
ThreadDataList.clear();
111217
}
218+
219+
template <class DeinitFuncTy> llvm::Error deinit(DeinitFuncTy DeinitFunc) {
220+
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
221+
"Deinit cannot be called while other threads are adding entries");
222+
for (std::shared_ptr<PerThreadData> ThreadData : ThreadDataList) {
223+
if (!ThreadData->ThreadEntry || ThreadData->Size == 0)
224+
continue;
225+
for (auto &Obj : *ThreadData->ThreadEntry) {
226+
if constexpr (ContainerConcepts<ContainerType>::isAssociative) {
227+
if (auto Err = DeinitFunc(Obj.second))
228+
return Err;
229+
} else {
230+
if (auto Err = DeinitFunc(Obj))
231+
return Err;
232+
}
233+
}
234+
}
235+
return llvm::Error::success();
236+
}
237+
};
238+
239+
template <typename ContainerType, size_t ReserveSize = 0>
240+
class PerThreadContainer
241+
: public PerThreadTable<ContainerType, typename ContainerConcepts<
242+
ContainerType>::value_type> {
243+
244+
using IndexType = typename ContainerConcepts<ContainerType>::key_type;
245+
using ObjectType = typename ContainerConcepts<ContainerType>::value_type;
246+
247+
public:
248+
// Get the object for the given index in the current thread.
249+
ObjectType &get(IndexType Index) {
250+
ContainerType &Entry = this->getThreadEntry();
251+
252+
// Specialized code for vector-like containers.
253+
if constexpr (ContainerConcepts<ContainerType>::hasResize) {
254+
if (Index >= Entry.size()) {
255+
if constexpr (ContainerConcepts<ContainerType>::hasReserve &&
256+
ReserveSize > 0)
257+
Entry.reserve(ReserveSize);
258+
259+
// If the index is out of bounds, try resize the container.
260+
Entry.resize(Index + 1);
261+
}
262+
}
263+
ObjectType &Ret = Entry[Index];
264+
this->setSize(Entry.size());
265+
return Ret;
266+
}
112267
};
113268

114269
#endif

0 commit comments

Comments
 (0)