|
16 | 16 | #include <list> |
17 | 17 | #include <memory> |
18 | 18 | #include <mutex> |
| 19 | +#include <type_traits> |
| 20 | + |
/// Lazily creates and tracks one ObjectType instance per calling thread.
///
/// A thread's object is default-constructed the first time that thread calls
/// get(). A shared_ptr to every per-thread record is also stored in
/// ThreadDataList so that clear() can visit the objects of all registered
/// threads.
///
/// NOTE(review): the thread_local slot is a function-local static of
/// getThreadData(), so it is per ObjectType instantiation and thread, not per
/// PerThread instance. Two PerThread<T> objects used from the same thread
/// appear to share one record, and only the first one to be used registers
/// it. Confirm each instantiation is used as a singleton.
template <typename ObjectType> struct PerThread {
  /// Per-thread record. ThEntry stays null until the owning thread first
  /// requests its object.
  struct PerThreadData {
    std::unique_ptr<ObjectType> ThEntry;
  };

  /// Guards ThreadDataList.
  std::mutex Mtx;
  /// One record per thread that has registered since the last clear().
  std::list<std::shared_ptr<PerThreadData>> ThreadDataList;

  // Default-constructible only; copying and moving are disabled.
  PerThread() = default;
  PerThread(const PerThread &) = delete;
  PerThread(PerThread &&) = delete;
  PerThread &operator=(const PerThread &) = delete;
  PerThread &operator=(PerThread &&) = delete;
  /// Drops this object's references to the per-thread records.
  ~PerThread() {
    std::lock_guard<std::mutex> Lock(Mtx);
    ThreadDataList.clear();
  }

private:
  /// Returns the calling thread's record, creating it and appending it to
  /// ThreadDataList (under Mtx) on the first call from that thread.
  PerThreadData &getThreadData() {
    static thread_local std::shared_ptr<PerThreadData> ThData;
    if (!ThData) {
      ThData = std::make_shared<PerThreadData>();
      std::lock_guard<std::mutex> Lock(Mtx);
      ThreadDataList.push_back(ThData);
    }
    return *ThData;
  }

protected:
  /// Returns the calling thread's object, default-constructing it on first
  /// use.
  ObjectType &getThreadEntry() {
    auto &ThData = getThreadData();
    if (!ThData.ThEntry)
      ThData.ThEntry = std::make_unique<ObjectType>();
    return *ThData.ThEntry;
  }

public:
  /// Returns a reference to the calling thread's object.
  ObjectType &get() { return getThreadEntry(); }

  /// Calls f on every registered thread's object (records whose object was
  /// never created are skipped), then empties ThreadDataList. Mtx is held for
  /// the whole traversal.
  ///
  /// Note: a thread keeps its record through its thread_local pointer, so the
  /// record is not registered again after clear(); a later get() on that
  /// thread returns the same object, and a later clear() does not visit it.
  template <class F> void clear(F f) {
    std::lock_guard<std::mutex> Lock(Mtx);
    // Iterate by reference: copying a shared_ptr costs an atomic refcount
    // increment and decrement per element.
    for (const auto &ThData : ThreadDataList) {
      if (ThData->ThEntry)
        f(*ThData->ThEntry);
    }
    ThreadDataList.clear();
  }
};
19 | 73 |
|
20 | 74 | // Using an STL container (such as std::vector) indexed by thread ID has
21 | 75 | // too many race-condition issues, so we store each thread's entry in a
22 | 76 | // thread_local variable.
23 | 77 | // ContainerType is the container each thread uses to store its objects,
24 | 78 | // e.g., std::vector, std::set, etc. ObjectType is the type of the stored
25 | 79 | // objects, e.g., omp_interop_val_t *, ...
26 | | - |
27 | 80 | template <typename ContainerType, typename ObjectType> struct PerThreadTable { |
28 | 81 | using iterator = typename ContainerType::iterator; |
29 | 82 |
|
/// Detection trait: `value` is true when `Candidate::iterator` names a type,
/// false otherwise.
template <class Candidate, class Enable = void>
struct has_iterator : std::false_type {};
template <class Candidate>
struct has_iterator<Candidate, std::void_t<typename Candidate::iterator>>
    : std::true_type {};
| 87 | + |
/// Detection trait: `value` is true when `clear()` can be called with no
/// arguments on a `Candidate`, false otherwise.
template <class Candidate, class Enable = void>
struct has_clear : std::false_type {};
template <class Candidate>
struct has_clear<Candidate,
                 std::void_t<decltype(std::declval<Candidate>().clear())>>
    : std::true_type {};
| 93 | + |
/// Detection trait: `value` is true when `clearAll(1)` is a well-formed call
/// on a `Candidate`, false otherwise. The probe passes an int literal, so the
/// result reflects whether `clearAll` accepts an int argument.
template <class Candidate, class Enable = void>
struct has_clearAll : std::false_type {};
template <class Candidate>
struct has_clearAll<
    Candidate, std::void_t<decltype(std::declval<Candidate>().clearAll(1))>>
    : std::true_type {};
| 99 | + |
/// Detection trait: `value` is true when `Candidate::mapped_type` names a
/// type (as in std::map), false otherwise.
template <class Candidate, class Enable = void>
struct is_associative : std::false_type {};
template <class Candidate>
struct is_associative<Candidate,
                      std::void_t<typename Candidate::mapped_type>>
    : std::true_type {};
| 105 | + |
30 | 106 | struct PerThreadData { |
31 | 107 | size_t NElements = 0; |
32 | 108 | std::unique_ptr<ContainerType> ThEntry; |
@@ -71,6 +147,11 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable { |
71 | 147 | return ThData.NElements; |
72 | 148 | } |
73 | 149 |
|
| 150 | + void setNElements(size_t Size) { |
| 151 | + auto &NElements = getThreadNElements(); |
| 152 | + NElements = Size; |
| 153 | + } |
| 154 | + |
74 | 155 | public: |
75 | 156 | void add(ObjectType obj) { |
76 | 157 | auto &Entry = getThreadEntry(); |
@@ -104,11 +185,81 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable { |
104 | 185 | for (auto ThData : ThreadDataList) { |
105 | 186 | if (!ThData->ThEntry || ThData->NElements == 0) |
106 | 187 | continue; |
107 | | - ThData->ThEntry->clear(f); |
| 188 | + if constexpr (has_clearAll<ContainerType>::value) { |
| 189 | + ThData->ThEntry->clearAll(f); |
| 190 | + } else if constexpr (has_iterator<ContainerType>::value && |
| 191 | + has_clear<ContainerType>::value) { |
| 192 | + for (auto &Obj : *ThData->ThEntry) { |
| 193 | + if constexpr (is_associative<ContainerType>::value) { |
| 194 | + f(Obj.second); |
| 195 | + } else { |
| 196 | + f(Obj); |
| 197 | + } |
| 198 | + } |
| 199 | + ThData->ThEntry->clear(); |
| 200 | + } else { |
| 201 | + static_assert(true, "Container type not supported"); |
| 202 | + } |
108 | 203 | ThData->NElements = 0; |
109 | 204 | } |
110 | 205 | ThreadDataList.clear(); |
111 | 206 | } |
112 | 207 | }; |
113 | 208 |
|
/// Maps a container type to the type of the objects it stores: `mapped_type`
/// when the container declares one (e.g. std::map), otherwise `value_type`.
template <class Container, class Enable = void> struct ContainerValueType {
  using type = typename Container::value_type;
};
template <class Container>
struct ContainerValueType<Container,
                          std::void_t<typename Container::mapped_type>> {
  using type = typename Container::mapped_type;
};
| 216 | + |
| 217 | +template <typename ContainerType, size_t reserveSize = 0> |
| 218 | +struct PerThreadContainer |
| 219 | + : public PerThreadTable<ContainerType, |
| 220 | + typename ContainerValueType<ContainerType>::type> { |
| 221 | + |
| 222 | + // helpers |
| 223 | + template <typename T, typename = std::void_t<>> struct indexType { |
| 224 | + using type = typename T::size_type; |
| 225 | + }; |
| 226 | + template <typename T> struct indexType<T, std::void_t<typename T::key_type>> { |
| 227 | + using type = typename T::key_type; |
| 228 | + }; |
| 229 | + template <typename T, typename = std::void_t<>> |
| 230 | + struct has_resize : std::false_type {}; |
| 231 | + template <typename T> |
| 232 | + struct has_resize<T, std::void_t<decltype(std::declval<T>().resize(1))>> |
| 233 | + : std::true_type {}; |
| 234 | + |
| 235 | + template <typename T, typename = std::void_t<>> |
| 236 | + struct has_reserve : std::false_type {}; |
| 237 | + template <typename T> |
| 238 | + struct has_reserve<T, std::void_t<decltype(std::declval<T>().reserve(1))>> |
| 239 | + : std::true_type {}; |
| 240 | + |
| 241 | + using IndexType = typename indexType<ContainerType>::type; |
| 242 | + using ObjectType = typename ContainerValueType<ContainerType>::type; |
| 243 | + |
| 244 | + // Get the object for the given index in the current thread |
| 245 | + ObjectType &get(IndexType Index) { |
| 246 | + auto &Entry = this->getThreadEntry(); |
| 247 | + |
| 248 | + // specialized code for vector-like containers |
| 249 | + if constexpr (has_resize<ContainerType>::value) { |
| 250 | + if (Index >= Entry.size()) { |
| 251 | + if constexpr (has_reserve<ContainerType>::value && reserveSize > 0) { |
| 252 | + if (Entry.capacity() < reserveSize) |
| 253 | + Entry.reserve(reserveSize); |
| 254 | + } |
| 255 | + // If the index is out of bounds, try resize the container |
| 256 | + Entry.resize(Index + 1); |
| 257 | + } |
| 258 | + } |
| 259 | + ObjectType &Ret = Entry[Index]; |
| 260 | + this->setNElements(Entry.size()); |
| 261 | + return Ret; |
| 262 | + } |
| 263 | +}; |
| 264 | + |
114 | 265 | #endif |
0 commit comments