Skip to content

Commit 32b5557

Browse files
committed
Add thread Barrier unit test
1 parent b8d26ff commit 32b5557

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

paddle/utils/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_simple_unittest(test_Logging)
33
add_simple_unittest(test_Thread)
44
add_simple_unittest(test_StringUtils)
55
add_simple_unittest(test_CustomStackTrace)
6+
add_simple_unittest(test_ThreadBarrier)
67

78
add_executable(
89
test_CustomStackTracePrint
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
#include <set>
17+
#include <vector>
18+
#include "paddle/utils/Logging.h"
19+
#include "paddle/utils/CommandLineParser.h"
20+
#include "paddle/utils/Util.h"
21+
#include "paddle/utils/Locks.h"
22+
23+
P_DEFINE_int32(test_thread_num, 100, "testing thread number");
24+
25+
void testNormalImpl(size_t thread_num,
26+
const std::function<void(size_t,
27+
std::mutex&, std::set<std::thread::id>&,
28+
paddle::ThreadBarrier&)>& callback) {
29+
std::mutex mutex;
30+
std::set<std::thread::id> tids;
31+
paddle::ThreadBarrier barrier(thread_num);
32+
33+
std::vector<std::thread> threads;
34+
threads.reserve(thread_num);
35+
for (int32_t i = 0; i < thread_num; ++i) {
36+
threads.emplace_back([&thread_num, &mutex,
37+
&tids, &barrier, &callback]{
38+
callback(thread_num, mutex, tids, barrier);
39+
});
40+
}
41+
42+
for (auto& thread : threads) {
43+
thread.join();
44+
}
45+
}
46+
47+
TEST(ThreadBarrier, normalTest) {
48+
for (auto &thread_num : {10, 30, 50 , 100 , 300, 1000}) {
49+
testNormalImpl(thread_num,
50+
[](size_t thread_num, std::mutex& mutex,
51+
std::set<std::thread::id>& tids,
52+
paddle::ThreadBarrier& barrier){
53+
{
54+
std::lock_guard<std::mutex> guard(mutex);
55+
tids.insert(std::this_thread::get_id());
56+
}
57+
barrier.wait();
58+
// Check whether all threads reach this point or not
59+
CHECK_EQ(tids.size(), thread_num);
60+
});
61+
}
62+
}
63+
64+
int main(int argc, char** argv) {
65+
testing::InitGoogleTest(&argc, argv);
66+
paddle::initMain(argc, argv);
67+
return RUN_ALL_TESTS();
68+
}

0 commit comments

Comments
 (0)