Skip to content

Commit 9c602be

Browse files
committed
Merge branch 'extend_ptt' into l0_plugin
2 parents 0f59337 + 2d4f364 commit 9c602be

File tree

1 file changed

+95
-116
lines changed

1 file changed

+95
-116
lines changed

offload/include/PerThreadTable.h

Lines changed: 95 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,21 @@
2020
#include <mutex>
2121
#include <type_traits>
2222

23-
template <typename ObjectType> struct PerThread {
24-
struct PerThreadData {
25-
std::unique_ptr<ObjectType> ThreadEntry;
26-
};
27-
23+
template <typename ObjectType> class PerThread {
2824
std::mutex Mutex;
29-
llvm::SmallVector<std::shared_ptr<PerThreadData>> ThreadDataList;
25+
llvm::SmallVector<std::shared_ptr<ObjectType>> ThreadDataList;
26+
27+
ObjectType &getThreadData() {
28+
static thread_local std::shared_ptr<ObjectType> ThreadData = nullptr;
29+
if (!ThreadData) {
30+
ThreadData = std::make_shared<ObjectType>();
31+
std::lock_guard<std::mutex> Lock(Mutex);
32+
ThreadDataList.push_back(ThreadData);
33+
}
34+
return *ThreadData;
35+
}
3036

37+
public:
3138
// define default constructors, disable copy and move constructors
3239
PerThread() = default;
3340
PerThread(const PerThread &) = delete;
@@ -40,72 +47,73 @@ template <typename ObjectType> struct PerThread {
4047
ThreadDataList.clear();
4148
}
4249

43-
private:
44-
PerThreadData &getThreadData() {
45-
static thread_local std::shared_ptr<PerThreadData> ThreadData = nullptr;
46-
if (!ThreadData) {
47-
ThreadData = std::make_shared<PerThreadData>();
48-
std::lock_guard<std::mutex> Lock(Mutex);
49-
ThreadDataList.push_back(ThreadData);
50-
}
51-
return *ThreadData;
52-
}
53-
54-
protected:
55-
ObjectType &getThreadEntry() {
56-
PerThreadData &ThreadData = getThreadData();
57-
if (ThreadData.ThreadEntry)
58-
return *ThreadData.ThreadEntry;
59-
ThreadData.ThreadEntry = std::make_unique<ObjectType>();
60-
return *ThreadData.ThreadEntry;
61-
}
62-
63-
public:
64-
ObjectType &get() { return getThreadEntry(); }
50+
ObjectType &get() { return getThreadData(); }
6551

6652
template <class ClearFuncTy> void clear(ClearFuncTy ClearFunc) {
6753
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
6854
"Clear cannot be called while other threads are adding entries");
69-
for (std::shared_ptr<PerThreadData> ThreadData : ThreadDataList) {
70-
if (!ThreadData->ThreadEntry)
55+
for (std::shared_ptr<ObjectType> ThreadData : ThreadDataList) {
56+
if (!ThreadData)
7157
continue;
72-
ClearFunc(*ThreadData->ThreadEntry);
58+
ClearFunc(*ThreadData);
7359
}
7460
ThreadDataList.clear();
7561
}
7662
};
7763

64+
/// Compile-time introspection of a container type.
///
/// Reports which nested types and member functions \p ContainerTy provides
/// and exposes normalized `iterator`, `value_type` and `key_type` aliases.
/// An alias whose underlying nested type is missing resolves to `void`
/// instead of producing a hard error.
template <typename ContainerTy> struct ContainerConcepts {
  // Detection helper: the primary template is chosen when Probe<T> is
  // ill-formed, the partial specialization when it is well-formed.
  template <typename, template <typename> class, typename = void>
  struct has : std::false_type {};
  template <typename T, template <typename> class Probe>
  struct has<T, Probe, std::void_t<Probe<T>>> : std::true_type {};

  // Same selection scheme as `has`, but yields the probed type itself, or
  // `void` when the probe is ill-formed.
  template <typename, template <typename> class, typename = void>
  struct has_type {
    using type = void;
  };
  template <typename T, template <typename> class Probe>
  struct has_type<T, Probe, std::void_t<Probe<T>>> {
    using type = Probe<T>;
  };

  // Probes for nested type names.
  template <typename T> using IteratorTypeCheck = typename T::iterator;
  template <typename T> using MappedTypeCheck = typename T::mapped_type;
  template <typename T> using ValueTypeCheck = typename T::value_type;
  template <typename T> using KeyTypeCheck = typename T::key_type;
  template <typename T> using SizeTyCheck = typename T::size_type;

  // Probes for member functions; each is well-formed only if the call is.
  template <typename T>
  using ClearCheck = decltype(std::declval<T>().clear());
  template <typename T>
  using ReserveCheck = decltype(std::declval<T>().reserve(1));
  template <typename T>
  using ResizeCheck = decltype(std::declval<T>().resize(1));

  // Boolean traits for ContainerTy.
  static constexpr bool hasIterator =
      has<ContainerTy, IteratorTypeCheck>::value;
  static constexpr bool hasClear = has<ContainerTy, ClearCheck>::value;
  // A container counts as associative when it declares `mapped_type`.
  static constexpr bool isAssociative =
      has<ContainerTy, MappedTypeCheck>::value;
  static constexpr bool hasReserve = has<ContainerTy, ReserveCheck>::value;
  static constexpr bool hasResize = has<ContainerTy, ResizeCheck>::value;

  // Normalized aliases. Associative containers report their mapped/key
  // types; all others report value_type and are indexed by size_type.
  using iterator = typename has_type<ContainerTy, IteratorTypeCheck>::type;
  using value_type =
      std::conditional_t<isAssociative,
                         typename has_type<ContainerTy, MappedTypeCheck>::type,
                         typename has_type<ContainerTy, ValueTypeCheck>::type>;
  using key_type =
      std::conditional_t<isAssociative,
                         typename has_type<ContainerTy, KeyTypeCheck>::type,
                         typename has_type<ContainerTy, SizeTyCheck>::type>;
};
108+
78109
// Using an STL container (such as std::vector) indexed by thread ID has
79110
// too many race-condition issues, so we store each thread entry in a
80111
// thread_local variable.
81-
// T is the container type used to store the objects, e.g., std::vector,
82-
// std::set, etc. by each thread. O is the type of the stored objects e.g.,
83-
// omp_interop_val_t *, ...
84-
template <typename ContainerType, typename ObjectType> struct PerThreadTable {
85-
using iterator = typename ContainerType::iterator;
86-
87-
template <typename, typename = std::void_t<>>
88-
struct has_iterator : std::false_type {};
89-
template <typename T>
90-
struct has_iterator<T, std::void_t<typename T::iterator>> : std::true_type {};
91-
92-
template <typename T, typename = std::void_t<>>
93-
struct has_clear : std::false_type {};
94-
template <typename T>
95-
struct has_clear<T, std::void_t<decltype(std::declval<T>().clear())>>
96-
: std::true_type {};
97-
98-
template <typename T, typename = std::void_t<>>
99-
struct has_clearAll : std::false_type {};
100-
template <typename T>
101-
struct has_clearAll<T, std::void_t<decltype(std::declval<T>().clearAll(1))>>
102-
: std::true_type {};
103-
104-
template <typename, typename = std::void_t<>>
105-
struct is_associative : std::false_type {};
106-
template <typename T>
107-
struct is_associative<T, std::void_t<typename T::mapped_type>>
108-
: std::true_type {};
112+
// ContainerType is the container type used to store the objects, e.g.,
113+
// std::vector, std::set, etc. by each thread. ObjectType is the type of the
114+
// stored objects e.g., omp_interop_val_t *, ...
115+
template <typename ContainerType, typename ObjectType> class PerThreadTable {
116+
using iterator = typename ContainerConcepts<ContainerType>::iterator;
109117

110118
struct PerThreadData {
111119
size_t NElements = 0;
@@ -115,19 +123,6 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
115123
std::mutex Mutex;
116124
llvm::SmallVector<std::shared_ptr<PerThreadData>> ThreadDataList;
117125

118-
// define default constructors, disable copy and move constructors
119-
PerThreadTable() = default;
120-
PerThreadTable(const PerThreadTable &) = delete;
121-
PerThreadTable(PerThreadTable &&) = delete;
122-
PerThreadTable &operator=(const PerThreadTable &) = delete;
123-
PerThreadTable &operator=(PerThreadTable &&) = delete;
124-
~PerThreadTable() {
125-
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
126-
"Cannot be deleted while other threads are adding entries");
127-
ThreadDataList.clear();
128-
}
129-
130-
private:
131126
PerThreadData &getThreadData() {
132127
static thread_local std::shared_ptr<PerThreadData> ThreadData = nullptr;
133128
if (!ThreadData) {
@@ -158,6 +153,18 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
158153
}
159154

160155
public:
156+
// define default constructors, disable copy and move constructors
157+
PerThreadTable() = default;
158+
PerThreadTable(const PerThreadTable &) = delete;
159+
PerThreadTable(PerThreadTable &&) = delete;
160+
PerThreadTable &operator=(const PerThreadTable &) = delete;
161+
PerThreadTable &operator=(PerThreadTable &&) = delete;
162+
~PerThreadTable() {
163+
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
164+
"Cannot be deleted while other threads are adding entries");
165+
ThreadDataList.clear();
166+
}
167+
161168
void add(ObjectType obj) {
162169
ContainerType &Entry = getThreadEntry();
163170
size_t &NElements = getThreadNElements();
@@ -191,12 +198,10 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
191198
for (std::shared_ptr<PerThreadData> ThreadData : ThreadDataList) {
192199
if (!ThreadData->ThreadEntry || ThreadData->NElements == 0)
193200
continue;
194-
if constexpr (has_clearAll<ContainerType>::value) {
195-
ThreadData->ThreadEntry->clearAll(ClearFunc);
196-
} else if constexpr (has_iterator<ContainerType>::value &&
197-
has_clear<ContainerType>::value) {
201+
if constexpr (ContainerConcepts<ContainerType>::hasIterator &&
202+
ContainerConcepts<ContainerType>::hasClear) {
198203
for (auto &Obj : *ThreadData->ThreadEntry) {
199-
if constexpr (is_associative<ContainerType>::value) {
204+
if constexpr (ContainerConcepts<ContainerType>::isAssociative) {
200205
ClearFunc(Obj.second);
201206
} else {
202207
ClearFunc(Obj);
@@ -218,7 +223,7 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
218223
if (!ThreadData->ThreadEntry || ThreadData->NElements == 0)
219224
continue;
220225
for (auto &Obj : *ThreadData->ThreadEntry) {
221-
if constexpr (is_associative<ContainerType>::value) {
226+
if constexpr (ContainerConcepts<ContainerType>::isAssociative) {
222227
if (auto Err = DeinitFunc(Obj.second))
223228
return Err;
224229
} else {
@@ -231,52 +236,26 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
231236
}
232237
};
233238

234-
template <typename T, typename = std::void_t<>> struct ContainerValueType {
235-
using type = typename T::value_type;
236-
};
237-
template <typename T>
238-
struct ContainerValueType<T, std::void_t<typename T::mapped_type>> {
239-
using type = typename T::mapped_type;
240-
};
239+
template <typename ContainerType, size_t ReserveSize = 0>
240+
class PerThreadContainer
241+
: public PerThreadTable<ContainerType, typename ContainerConcepts<
242+
ContainerType>::value_type> {
241243

242-
template <typename ContainerType, size_t reserveSize = 0>
243-
struct PerThreadContainer
244-
: public PerThreadTable<ContainerType,
245-
typename ContainerValueType<ContainerType>::type> {
246-
247-
// helpers
248-
template <typename T, typename = std::void_t<>> struct indexType {
249-
using type = typename T::size_type;
250-
};
251-
template <typename T> struct indexType<T, std::void_t<typename T::key_type>> {
252-
using type = typename T::key_type;
253-
};
254-
template <typename T, typename = std::void_t<>>
255-
struct has_resize : std::false_type {};
256-
template <typename T>
257-
struct has_resize<T, std::void_t<decltype(std::declval<T>().resize(1))>>
258-
: std::true_type {};
259-
260-
template <typename T, typename = std::void_t<>>
261-
struct has_reserve : std::false_type {};
262-
template <typename T>
263-
struct has_reserve<T, std::void_t<decltype(std::declval<T>().reserve(1))>>
264-
: std::true_type {};
265-
266-
using IndexType = typename indexType<ContainerType>::type;
267-
using ObjectType = typename ContainerValueType<ContainerType>::type;
244+
using IndexType = typename ContainerConcepts<ContainerType>::key_type;
245+
using ObjectType = typename ContainerConcepts<ContainerType>::value_type;
268246

247+
public:
269248
// Get the object for the given index in the current thread
270249
ObjectType &get(IndexType Index) {
271250
ContainerType &Entry = this->getThreadEntry();
272251

273252
// specialized code for vector-like containers
274-
if constexpr (has_resize<ContainerType>::value) {
253+
if constexpr (ContainerConcepts<ContainerType>::hasResize) {
275254
if (Index >= Entry.size()) {
276-
if constexpr (has_reserve<ContainerType>::value && reserveSize > 0) {
277-
if (Entry.capacity() < reserveSize)
278-
Entry.reserve(reserveSize);
279-
}
255+
if constexpr (ContainerConcepts<ContainerType>::hasReserve &&
256+
ReserveSize > 0)
257+
Entry.reserve(ReserveSize);
258+
280259
// If the index is out of bounds, try to resize the container
281260
Entry.resize(Index + 1);
282261
}

0 commit comments

Comments
 (0)