|
14 | 14 | #define OFFLOAD_PERTHREADTABLE_H |
15 | 15 |
|
16 | 16 | #include <list> |
| 17 | +#include <llvm/ADT/SmallVector.h> |
| 18 | +#include <llvm/Support/Error.h> |
17 | 19 | #include <memory> |
18 | 20 | #include <mutex> |
| 21 | +#include <type_traits> |
| 22 | + |
| 23 | +template <typename ObjectType> class PerThread { |
| 24 | + std::mutex Mutex; |
| 25 | + llvm::SmallVector<std::shared_ptr<ObjectType>> ThreadDataList; |
| 26 | + |
| 27 | + ObjectType &getThreadData() { |
| 28 | + static thread_local std::shared_ptr<ObjectType> ThreadData = nullptr; |
| 29 | + if (!ThreadData) { |
| 30 | + ThreadData = std::make_shared<ObjectType>(); |
| 31 | + std::lock_guard<std::mutex> Lock(Mutex); |
| 32 | + ThreadDataList.push_back(ThreadData); |
| 33 | + } |
| 34 | + return *ThreadData; |
| 35 | + } |
| 36 | + |
| 37 | +public: |
| 38 | + // Define default constructors, disable copy and move constructors. |
| 39 | + PerThread() = default; |
| 40 | + PerThread(const PerThread &) = delete; |
| 41 | + PerThread(PerThread &&) = delete; |
| 42 | + PerThread &operator=(const PerThread &) = delete; |
| 43 | + PerThread &operator=(PerThread &&) = delete; |
| 44 | + ~PerThread() { |
| 45 | + assert(Mutex.try_lock() && (Mutex.unlock(), true) && |
| 46 | + "Cannot be deleted while other threads are adding entries"); |
| 47 | + ThreadDataList.clear(); |
| 48 | + } |
| 49 | + |
| 50 | + ObjectType &get() { return getThreadData(); } |
| 51 | + |
| 52 | + template <class ClearFuncTy> void clear(ClearFuncTy ClearFunc) { |
| 53 | + assert(Mutex.try_lock() && (Mutex.unlock(), true) && |
| 54 | + "Clear cannot be called while other threads are adding entries"); |
| 55 | + for (std::shared_ptr<ObjectType> ThreadData : ThreadDataList) { |
| 56 | + if (!ThreadData) |
| 57 | + continue; |
| 58 | + ClearFunc(*ThreadData); |
| 59 | + } |
| 60 | + ThreadDataList.clear(); |
| 61 | + } |
| 62 | +}; |
| 63 | + |
| 64 | +template <typename ContainerTy> struct ContainerConcepts { |
| 65 | + template <typename, template <typename> class, typename = std::void_t<>> |
| 66 | + struct has : std::false_type {}; |
| 67 | + template <typename Ty, template <typename> class Op> |
| 68 | + struct has<Ty, Op, std::void_t<Op<Ty>>> : std::true_type {}; |
| 69 | + |
| 70 | + template <typename Ty> using IteratorTypeCheck = typename Ty::iterator; |
| 71 | + template <typename Ty> using MappedTypeCheck = typename Ty::mapped_type; |
| 72 | + template <typename Ty> using ValueTypeCheck = typename Ty::value_type; |
| 73 | + template <typename Ty> using KeyTypeCheck = typename Ty::key_type; |
| 74 | + template <typename Ty> using SizeTypeCheck = typename Ty::size_type; |
| 75 | + |
| 76 | + template <typename Ty> |
| 77 | + using ClearCheck = decltype(std::declval<Ty>().clear()); |
| 78 | + template <typename Ty> |
| 79 | + using ReserveCheck = decltype(std::declval<Ty>().reserve(1)); |
| 80 | + template <typename Ty> |
| 81 | + using ResizeCheck = decltype(std::declval<Ty>().resize(1)); |
| 82 | + |
| 83 | + static constexpr bool hasIterator = |
| 84 | + has<ContainerTy, IteratorTypeCheck>::value; |
| 85 | + static constexpr bool hasClear = has<ContainerTy, ClearCheck>::value; |
| 86 | + static constexpr bool isAssociative = |
| 87 | + has<ContainerTy, MappedTypeCheck>::value; |
| 88 | + static constexpr bool hasReserve = has<ContainerTy, ReserveCheck>::value; |
| 89 | + static constexpr bool hasResize = has<ContainerTy, ResizeCheck>::value; |
| 90 | + |
| 91 | + template <typename, template <typename> class, typename = std::void_t<>> |
| 92 | + struct has_type { |
| 93 | + using type = void; |
| 94 | + }; |
| 95 | + template <typename Ty, template <typename> class Op> |
| 96 | + struct has_type<Ty, Op, std::void_t<Op<Ty>>> { |
| 97 | + using type = Op<Ty>; |
| 98 | + }; |
| 99 | + |
| 100 | + using iterator = typename has_type<ContainerTy, IteratorTypeCheck>::type; |
| 101 | + using value_type = typename std::conditional_t< |
| 102 | + isAssociative, typename has_type<ContainerTy, MappedTypeCheck>::type, |
| 103 | + typename has_type<ContainerTy, ValueTypeCheck>::type>; |
| 104 | + using key_type = typename std::conditional_t< |
| 105 | + isAssociative, typename has_type<ContainerTy, KeyTypeCheck>::type, |
| 106 | + typename has_type<ContainerTy, SizeTypeCheck>::type>; |
| 107 | +}; |
19 | 108 |
|
20 | 109 | // Using an STL container (such as std::vector) indexed by thread ID has |
21 | 110 | // too many race conditions issues so we store each thread entry into a |
22 | 111 | // thread_local variable. |
23 | | -// T is the container type used to store the objects, e.g., std::vector, |
24 | | -// std::set, etc. by each thread. O is the type of the stored objects e.g., |
25 | | -// omp_interop_val_t *, ... |
26 | | - |
27 | | -template <typename ContainerType, typename ObjectType> struct PerThreadTable { |
28 | | - using iterator = typename ContainerType::iterator; |
| 112 | +// ContainerType is the container type used to store the objects, e.g., |
| 113 | +// std::vector, std::set, etc. by each thread. ObjectType is the type of the |
| 114 | +// stored objects e.g., omp_interop_val_t *, ... |
| 115 | +template <typename ContainerType, typename ObjectType> class PerThreadTable { |
| 116 | + using iterator = typename ContainerConcepts<ContainerType>::iterator; |
29 | 117 |
|
30 | 118 | struct PerThreadData { |
31 | | - size_t NElements = 0; |
32 | | - std::unique_ptr<ContainerType> ThEntry; |
| 119 | + size_t Size = 0; |
| 120 | + std::unique_ptr<ContainerType> ThreadEntry; |
33 | 121 | }; |
34 | 122 |
|
35 | | - std::mutex Mtx; |
36 | | - std::list<std::shared_ptr<PerThreadData>> ThreadDataList; |
37 | | - |
38 | | - // define default constructors, disable copy and move constructors |
39 | | - PerThreadTable() = default; |
40 | | - PerThreadTable(const PerThreadTable &) = delete; |
41 | | - PerThreadTable(PerThreadTable &&) = delete; |
42 | | - PerThreadTable &operator=(const PerThreadTable &) = delete; |
43 | | - PerThreadTable &operator=(PerThreadTable &&) = delete; |
44 | | - ~PerThreadTable() { |
45 | | - std::lock_guard<std::mutex> Lock(Mtx); |
46 | | - ThreadDataList.clear(); |
47 | | - } |
| 123 | + std::mutex Mutex; |
| 124 | + llvm::SmallVector<std::shared_ptr<PerThreadData>> ThreadDataList; |
48 | 125 |
|
49 | | -private: |
50 | 126 | PerThreadData &getThreadData() { |
51 | | - static thread_local std::shared_ptr<PerThreadData> ThData = nullptr; |
52 | | - if (!ThData) { |
53 | | - ThData = std::make_shared<PerThreadData>(); |
54 | | - std::lock_guard<std::mutex> Lock(Mtx); |
55 | | - ThreadDataList.push_back(ThData); |
| 127 | + static thread_local std::shared_ptr<PerThreadData> ThreadData = nullptr; |
| 128 | + if (!ThreadData) { |
| 129 | + ThreadData = std::make_shared<PerThreadData>(); |
| 130 | + std::lock_guard<std::mutex> Lock(Mutex); |
| 131 | + ThreadDataList.push_back(ThreadData); |
56 | 132 | } |
57 | | - return *ThData; |
| 133 | + return *ThreadData; |
58 | 134 | } |
59 | 135 |
|
60 | 136 | protected: |
61 | 137 | ContainerType &getThreadEntry() { |
62 | | - auto &ThData = getThreadData(); |
63 | | - if (ThData.ThEntry) |
64 | | - return *ThData.ThEntry; |
65 | | - ThData.ThEntry = std::make_unique<ContainerType>(); |
66 | | - return *ThData.ThEntry; |
| 138 | + PerThreadData &ThreadData = getThreadData(); |
| 139 | + if (ThreadData.ThreadEntry) |
| 140 | + return *ThreadData.ThreadEntry; |
| 141 | + ThreadData.ThreadEntry = std::make_unique<ContainerType>(); |
| 142 | + return *ThreadData.ThreadEntry; |
| 143 | + } |
| 144 | + |
| 145 | + size_t &getThreadSize() { |
| 146 | + PerThreadData &ThreadData = getThreadData(); |
| 147 | + return ThreadData.Size; |
67 | 148 | } |
68 | 149 |
|
69 | | - size_t &getThreadNElements() { |
70 | | - auto &ThData = getThreadData(); |
71 | | - return ThData.NElements; |
| 150 | + void setSize(size_t Size) { |
| 151 | + size_t &SizeRef = getThreadSize(); |
| 152 | + SizeRef = Size; |
72 | 153 | } |
73 | 154 |
|
74 | 155 | public: |
| 156 | + // define default constructors, disable copy and move constructors. |
| 157 | + PerThreadTable() = default; |
| 158 | + PerThreadTable(const PerThreadTable &) = delete; |
| 159 | + PerThreadTable(PerThreadTable &&) = delete; |
| 160 | + PerThreadTable &operator=(const PerThreadTable &) = delete; |
| 161 | + PerThreadTable &operator=(PerThreadTable &&) = delete; |
| 162 | + ~PerThreadTable() { |
| 163 | + assert(Mutex.try_lock() && (Mutex.unlock(), true) && |
| 164 | + "Cannot be deleted while other threads are adding entries"); |
| 165 | + ThreadDataList.clear(); |
| 166 | + } |
| 167 | + |
75 | 168 | void add(ObjectType obj) { |
76 | | - auto &Entry = getThreadEntry(); |
77 | | - auto &NElements = getThreadNElements(); |
78 | | - NElements++; |
| 169 | + ContainerType &Entry = getThreadEntry(); |
| 170 | + size_t &SizeRef = getThreadSize(); |
| 171 | + SizeRef++; |
79 | 172 | Entry.add(obj); |
80 | 173 | } |
81 | 174 |
|
82 | 175 | iterator erase(iterator it) { |
83 | | - auto &Entry = getThreadEntry(); |
84 | | - auto &NElements = getThreadNElements(); |
85 | | - NElements--; |
| 176 | + ContainerType &Entry = getThreadEntry(); |
| 177 | + size_t &SizeRef = getThreadSize(); |
| 178 | + SizeRef--; |
86 | 179 | return Entry.erase(it); |
87 | 180 | } |
88 | 181 |
|
89 | | - size_t size() { return getThreadNElements(); } |
| 182 | + size_t size() { return getThreadSize(); } |
90 | 183 |
|
91 | 184 | // Iterators to traverse objects owned by |
92 | | - // the current thread |
| 185 | + // the current thread. |
93 | 186 | iterator begin() { |
94 | | - auto &Entry = getThreadEntry(); |
| 187 | + ContainerType &Entry = getThreadEntry(); |
95 | 188 | return Entry.begin(); |
96 | 189 | } |
97 | 190 | iterator end() { |
98 | | - auto &Entry = getThreadEntry(); |
| 191 | + ContainerType &Entry = getThreadEntry(); |
99 | 192 | return Entry.end(); |
100 | 193 | } |
101 | 194 |
|
102 | | - template <class F> void clear(F f) { |
103 | | - std::lock_guard<std::mutex> Lock(Mtx); |
104 | | - for (auto ThData : ThreadDataList) { |
105 | | - if (!ThData->ThEntry || ThData->NElements == 0) |
| 195 | + template <class ClearFuncTy> void clear(ClearFuncTy ClearFunc) { |
| 196 | + assert(Mutex.try_lock() && (Mutex.unlock(), true) && |
| 197 | + "Clear cannot be called while other threads are adding entries"); |
| 198 | + for (std::shared_ptr<PerThreadData> ThreadData : ThreadDataList) { |
| 199 | + if (!ThreadData->ThreadEntry || ThreadData->Size == 0) |
106 | 200 | continue; |
107 | | - ThData->ThEntry->clear(f); |
108 | | - ThData->NElements = 0; |
| 201 | + if constexpr (ContainerConcepts<ContainerType>::hasIterator && |
| 202 | + ContainerConcepts<ContainerType>::hasClear) { |
| 203 | + for (auto &Obj : *ThreadData->ThreadEntry) { |
| 204 | + if constexpr (ContainerConcepts<ContainerType>::isAssociative) { |
| 205 | + ClearFunc(Obj.second); |
| 206 | + } else { |
| 207 | + ClearFunc(Obj); |
| 208 | + } |
| 209 | + } |
| 210 | + ThreadData->ThreadEntry->clear(); |
| 211 | + } else { |
| 212 | + static_assert(true, "Container type not supported"); |
| 213 | + } |
| 214 | + ThreadData->Size = 0; |
109 | 215 | } |
110 | 216 | ThreadDataList.clear(); |
111 | 217 | } |
| 218 | + |
| 219 | + template <class DeinitFuncTy> llvm::Error deinit(DeinitFuncTy DeinitFunc) { |
| 220 | + assert(Mutex.try_lock() && (Mutex.unlock(), true) && |
| 221 | + "Deinit cannot be called while other threads are adding entries"); |
| 222 | + for (std::shared_ptr<PerThreadData> ThreadData : ThreadDataList) { |
| 223 | + if (!ThreadData->ThreadEntry || ThreadData->Size == 0) |
| 224 | + continue; |
| 225 | + for (auto &Obj : *ThreadData->ThreadEntry) { |
| 226 | + if constexpr (ContainerConcepts<ContainerType>::isAssociative) { |
| 227 | + if (auto Err = DeinitFunc(Obj.second)) |
| 228 | + return Err; |
| 229 | + } else { |
| 230 | + if (auto Err = DeinitFunc(Obj)) |
| 231 | + return Err; |
| 232 | + } |
| 233 | + } |
| 234 | + } |
| 235 | + return llvm::Error::success(); |
| 236 | + } |
| 237 | +}; |
| 238 | + |
| 239 | +template <typename ContainerType, size_t ReserveSize = 0> |
| 240 | +class PerThreadContainer |
| 241 | + : public PerThreadTable<ContainerType, typename ContainerConcepts< |
| 242 | + ContainerType>::value_type> { |
| 243 | + |
| 244 | + using IndexType = typename ContainerConcepts<ContainerType>::key_type; |
| 245 | + using ObjectType = typename ContainerConcepts<ContainerType>::value_type; |
| 246 | + |
| 247 | +public: |
| 248 | + // Get the object for the given index in the current thread. |
| 249 | + ObjectType &get(IndexType Index) { |
| 250 | + ContainerType &Entry = this->getThreadEntry(); |
| 251 | + |
| 252 | + // Specialized code for vector-like containers. |
| 253 | + if constexpr (ContainerConcepts<ContainerType>::hasResize) { |
| 254 | + if (Index >= Entry.size()) { |
| 255 | + if constexpr (ContainerConcepts<ContainerType>::hasReserve && |
| 256 | + ReserveSize > 0) |
| 257 | + Entry.reserve(ReserveSize); |
| 258 | + |
| 259 | + // If the index is out of bounds, try resize the container. |
| 260 | + Entry.resize(Index + 1); |
| 261 | + } |
| 262 | + } |
| 263 | + ObjectType &Ret = Entry[Index]; |
| 264 | + this->setSize(Entry.size()); |
| 265 | + return Ret; |
| 266 | + } |
112 | 267 | }; |
113 | 268 |
|
114 | 269 | #endif |
0 commit comments