Skip to content

Commit 2d4f364

Browse files
committed
Simplify PerThreadData
1 parent 32819df commit 2d4f364

File tree

1 file changed

+8
-20
lines changed

1 file changed

+8
-20
lines changed

offload/include/PerThreadTable.h

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,31 +21,19 @@
2121
#include <type_traits>
2222

2323
template <typename ObjectType> class PerThread {
24-
struct PerThreadData {
25-
std::unique_ptr<ObjectType> ThreadEntry;
26-
};
27-
2824
std::mutex Mutex;
29-
llvm::SmallVector<std::shared_ptr<PerThreadData>> ThreadDataList;
25+
llvm::SmallVector<std::shared_ptr<ObjectType>> ThreadDataList;
3026

31-
PerThreadData &getThreadData() {
32-
static thread_local std::shared_ptr<PerThreadData> ThreadData = nullptr;
27+
ObjectType &getThreadData() {
28+
static thread_local std::shared_ptr<ObjectType> ThreadData = nullptr;
3329
if (!ThreadData) {
34-
ThreadData = std::make_shared<PerThreadData>();
30+
ThreadData = std::make_shared<ObjectType>();
3531
std::lock_guard<std::mutex> Lock(Mutex);
3632
ThreadDataList.push_back(ThreadData);
3733
}
3834
return *ThreadData;
3935
}
4036

41-
ObjectType &getThreadEntry() {
42-
PerThreadData &ThreadData = getThreadData();
43-
if (ThreadData.ThreadEntry)
44-
return *ThreadData.ThreadEntry;
45-
ThreadData.ThreadEntry = std::make_unique<ObjectType>();
46-
return *ThreadData.ThreadEntry;
47-
}
48-
4937
public:
5038
// define default constructors, disable copy and move constructors
5139
PerThread() = default;
@@ -59,15 +47,15 @@ template <typename ObjectType> class PerThread {
5947
ThreadDataList.clear();
6048
}
6149

62-
ObjectType &get() { return getThreadEntry(); }
50+
ObjectType &get() { return getThreadData(); }
6351

6452
template <class ClearFuncTy> void clear(ClearFuncTy ClearFunc) {
6553
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
6654
"Clear cannot be called while other threads are adding entries");
67-
for (std::shared_ptr<PerThreadData> ThreadData : ThreadDataList) {
68-
if (!ThreadData->ThreadEntry)
55+
for (std::shared_ptr<ObjectType> ThreadData : ThreadDataList) {
56+
if (!ThreadData)
6957
continue;
70-
ClearFunc(*ThreadData->ThreadEntry);
58+
ClearFunc(*ThreadData);
7159
}
7260
ThreadDataList.clear();
7361
}

0 commit comments

Comments
 (0)