Skip to content

Commit 579264a

Browse files
committed
minor refactoring
1 parent a14ee62 commit 579264a

File tree

1 file changed

+35
-39
lines changed

1 file changed

+35
-39
lines changed

offload/include/PerThreadTable.h

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,14 @@
2020
#include <mutex>
2121
#include <type_traits>
2222

23-
template <typename ObjectType> struct PerThread {
23+
template <typename ObjectType> class PerThread {
2424
struct PerThreadData {
2525
std::unique_ptr<ObjectType> ThreadEntry;
2626
};
2727

2828
std::mutex Mutex;
2929
llvm::SmallVector<std::shared_ptr<PerThreadData>> ThreadDataList;
3030

31-
// define default constructors, disable copy and move constructors
32-
PerThread() = default;
33-
PerThread(const PerThread &) = delete;
34-
PerThread(PerThread &&) = delete;
35-
PerThread &operator=(const PerThread &) = delete;
36-
PerThread &operator=(PerThread &&) = delete;
37-
~PerThread() {
38-
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
39-
"Cannot be deleted while other threads are adding entries");
40-
ThreadDataList.clear();
41-
}
42-
43-
private:
4431
PerThreadData &getThreadData() {
4532
static thread_local std::shared_ptr<PerThreadData> ThreadData = nullptr;
4633
if (!ThreadData) {
@@ -51,7 +38,6 @@ template <typename ObjectType> struct PerThread {
5138
return *ThreadData;
5239
}
5340

54-
protected:
5541
ObjectType &getThreadEntry() {
5642
PerThreadData &ThreadData = getThreadData();
5743
if (ThreadData.ThreadEntry)
@@ -61,6 +47,18 @@ template <typename ObjectType> struct PerThread {
6147
}
6248

6349
public:
50+
// define default constructors, disable copy and move constructors
51+
PerThread() = default;
52+
PerThread(const PerThread &) = delete;
53+
PerThread(PerThread &&) = delete;
54+
PerThread &operator=(const PerThread &) = delete;
55+
PerThread &operator=(PerThread &&) = delete;
56+
~PerThread() {
57+
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
58+
"Cannot be deleted while other threads are adding entries");
59+
ThreadDataList.clear();
60+
}
61+
6462
ObjectType &get() { return getThreadEntry(); }
6563

6664
template <class ClearFuncTy> void clear(ClearFuncTy ClearFunc) {
@@ -78,10 +76,10 @@ template <typename ObjectType> struct PerThread {
7876
// Using an STL container (such as std::vector) indexed by thread ID has
7977
// too many race conditions issues so we store each thread entry into a
8078
// thread_local variable.
81-
// T is the container type used to store the objects, e.g., std::vector,
82-
// std::set, etc. by each thread. O is the type of the stored objects e.g.,
83-
// omp_interop_val_t *, ...
84-
template <typename ContainerType, typename ObjectType> struct PerThreadTable {
79+
// ContainerType is the container type used to store the objects, e.g.,
80+
// std::vector, std::set, etc. by each thread. ObjectType is the type of the
81+
// stored objects e.g., omp_interop_val_t *, ...
82+
template <typename ContainerType, typename ObjectType> class PerThreadTable {
8583
using iterator = typename ContainerType::iterator;
8684

8785
template <typename, typename = std::void_t<>>
@@ -115,19 +113,6 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
115113
std::mutex Mutex;
116114
llvm::SmallVector<std::shared_ptr<PerThreadData>> ThreadDataList;
117115

118-
// define default constructors, disable copy and move constructors
119-
PerThreadTable() = default;
120-
PerThreadTable(const PerThreadTable &) = delete;
121-
PerThreadTable(PerThreadTable &&) = delete;
122-
PerThreadTable &operator=(const PerThreadTable &) = delete;
123-
PerThreadTable &operator=(PerThreadTable &&) = delete;
124-
~PerThreadTable() {
125-
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
126-
"Cannot be deleted while other threads are adding entries");
127-
ThreadDataList.clear();
128-
}
129-
130-
private:
131116
PerThreadData &getThreadData() {
132117
static thread_local std::shared_ptr<PerThreadData> ThreadData = nullptr;
133118
if (!ThreadData) {
@@ -158,6 +143,18 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
158143
}
159144

160145
public:
146+
// define default constructors, disable copy and move constructors
147+
PerThreadTable() = default;
148+
PerThreadTable(const PerThreadTable &) = delete;
149+
PerThreadTable(PerThreadTable &&) = delete;
150+
PerThreadTable &operator=(const PerThreadTable &) = delete;
151+
PerThreadTable &operator=(PerThreadTable &&) = delete;
152+
~PerThreadTable() {
153+
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
154+
"Cannot be deleted while other threads are adding entries");
155+
ThreadDataList.clear();
156+
}
157+
161158
void add(ObjectType obj) {
162159
ContainerType &Entry = getThreadEntry();
163160
size_t &NElements = getThreadNElements();
@@ -239,8 +236,8 @@ struct ContainerValueType<T, std::void_t<typename T::mapped_type>> {
239236
using type = typename T::mapped_type;
240237
};
241238

242-
template <typename ContainerType, size_t reserveSize = 0>
243-
struct PerThreadContainer
239+
template <typename ContainerType, size_t ReserveSize = 0>
240+
class PerThreadContainer
244241
: public PerThreadTable<ContainerType,
245242
typename ContainerValueType<ContainerType>::type> {
246243

@@ -265,18 +262,17 @@ struct PerThreadContainer
265262

266263
using IndexType = typename indexType<ContainerType>::type;
267264
using ObjectType = typename ContainerValueType<ContainerType>::type;
268-
265+
public:
269266
// Get the object for the given index in the current thread
270267
ObjectType &get(IndexType Index) {
271268
ContainerType &Entry = this->getThreadEntry();
272269

273270
// specialized code for vector-like containers
274271
if constexpr (has_resize<ContainerType>::value) {
275272
if (Index >= Entry.size()) {
276-
if constexpr (has_reserve<ContainerType>::value && reserveSize > 0) {
277-
if (Entry.capacity() < reserveSize)
278-
Entry.reserve(reserveSize);
279-
}
273+
if constexpr (has_reserve<ContainerType>::value && ReserveSize > 0)
274+
Entry.reserve(ReserveSize);
275+
280276
// If the index is out of bounds, try resize the container
281277
Entry.resize(Index + 1);
282278
}

0 commit comments

Comments
 (0)